Chapter 4: Geocentric Models

[1]:
%load_ext jupyter_black
[2]:
import jax
import arviz as az
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
from numpyro.infer import Predictive, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoLaplaceApproximation
import numpyro.distributions as dist
import pandas as pd
import plotly
import plotly.graph_objects as go
import plotly.io as pio
from scipy import stats, optimize
from scipy.interpolate import BSpline

# Notebook-wide configuration. One seed drives both the NumPy generator
# (`rng`, used with SciPy) and the root JAX PRNG key (`jrng`, used with NumPyro).
seed = 84735
rng = np.random.default_rng(seed=seed)
jrng = jax.random.key(seed)

# Plot styling: pandas `.plot()` renders through plotly unless a cell passes
# backend="matplotlib".
pd.options.plotting.backend = "plotly"
pio.templates.default = "plotly_white"
WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

Code

Code 4.1

[3]:
# Code 4.1 (NumPyro version): 1,000 people each take 16 steps of Uniform(-1, 1).
# Sample with a fresh subkey and advance `jrng`, so later cells do not reuse
# this key (reusing a JAX key reproduces the same "random" numbers).
jrng, key = jax.random.split(jrng)
steps = dist.Uniform(low=-1, high=1).sample(key, sample_shape=(1_000, 16))
[4]:
# Code 4.1 (SciPy version): the same simulation, rescaling Uniform(0, 1) draws
# to Uniform(-1, 1). Passing the notebook's seeded generator makes the result
# reproducible; without `random_state`, SciPy draws from NumPy's unseeded
# global state.
steps = -1 + 2 * stats.uniform.rvs(size=(1_000, 16), random_state=rng)
[5]:
# Final position of each person is the sum of their 16 steps; the histogram
# should look approximately Gaussian.
positions = steps.sum(axis=1)
pd.DataFrame(positions).plot(kind="hist")

Code 4.2

[6]:
# Code 4.2: growth after 12 multiplicative increments of Uniform(0, 0.1).
# A fresh subkey keeps this draw independent of other cells' draws.
jrng, key = jax.random.split(jrng)
steps = dist.Uniform(low=0, high=0.1).sample(key, sample_shape=(12,))
jnp.prod(1 + steps)
[6]:
Array(1.8501737, dtype=float32)

Code 4.3

[7]:
# Code 4.3: 10,000 replicates of the product of 12 small growth factors, which
# is approximately Gaussian. Use a fresh subkey; with the shared root key this
# would be bitwise identical to `small` in Code 4.4.
jrng, key = jax.random.split(jrng)
growth = jnp.prod(
    1 + dist.Uniform(low=0, high=0.1).sample(key, sample_shape=(10_000, 12)), axis=1
)
az.plot_density({"growth": growth}, hdi_prob=1)
[7]:
array([[<Axes: title={'center': 'growth'}>]], dtype=object)
../_images/notebooks_04_geocentric_models_10_1.png

Code 4.4

[8]:
# Code 4.4: large multiplicative effects (up to 50%) versus small ones (up to
# 10%). Each simulation gets its own subkey. With one shared key, both Uniform
# draws come from the same underlying uniforms, so `small` would just be a
# rescaled copy of `big`'s randomness rather than an independent simulation.
jrng, key_big, key_small = jax.random.split(jrng, 3)
big = jnp.prod(
    1 + dist.Uniform(low=0, high=0.5).sample(key_big, sample_shape=(10_000, 12)),
    axis=1,
)
small = jnp.prod(
    1 + dist.Uniform(low=0, high=0.1).sample(key_small, sample_shape=(10_000, 12)),
    axis=1,
)
az.plot_density({"big": big, "small": small}, hdi_prob=1)
[8]:
array([[<Axes: title={'center': 'big'}>,
        <Axes: title={'center': 'small'}>]], dtype=object)
../_images/notebooks_04_geocentric_models_12_1.png

Code 4.5

[9]:
# Code 4.5: on the log scale, even large multiplicative effects are
# approximately Gaussian. The dashed curve is a Normal with the sample mean
# and sd. A fresh subkey keeps this draw independent of other cells' draws.
jrng, key = jax.random.split(jrng)
log_big = jnp.log(
    jnp.prod(
        1 + dist.Uniform(low=0, high=0.5).sample(key, sample_shape=(10_000, 12)),
        axis=1,
    )
)
ax = az.plot_density({"log_big": log_big}, hdi_prob=1)
x = jnp.sort(log_big)
gaussian = jnp.exp(dist.Normal(jnp.mean(x), jnp.std(x)).log_prob(x))
ax[0][0].plot(x, gaussian, "--")
[9]:
[<matplotlib.lines.Line2D at 0x74c913b53740>]
../_images/notebooks_04_geocentric_models_14_1.png

Code 4.6

[10]:
# Code 4.6: grid approximation of the globe-tossing posterior
# (6 waters in 9 tosses, flat prior on p), normalized to sum to one.
w, n = 6, 9
p_grid = jnp.linspace(0, 1, 100)
likelihood = jnp.exp(dist.Binomial(total_count=n, probs=p_grid).log_prob(w))
prior = jnp.exp(dist.Uniform(low=0, high=1).log_prob(p_grid))
unnormalized = likelihood * prior
posterior = unnormalized / unnormalized.sum()
pd.DataFrame(posterior, index=p_grid).plot()

Code 4.7

[11]:
# Code 4.7: load the Howell1 census data (semicolon-separated).
df = pd.read_csv(filepath_or_buffer="../data/Howell1.csv", sep=";")

Code 4.8

[12]:
# Code 4.8: display the data frame (pandas elides the middle rows).
df
[12]:
height weight age male
0 151.765 47.825606 63.0 1
1 139.700 36.485807 63.0 0
2 136.525 31.864838 65.0 0
3 156.845 53.041914 41.0 1
4 145.415 41.276872 51.0 0
... ... ... ... ...
539 145.415 31.127751 17.0 1
540 162.560 52.163080 31.0 1
541 156.210 54.062497 21.0 0
542 71.120 8.051258 0.0 1
543 158.750 52.531624 68.0 1

544 rows × 4 columns

Code 4.9

[13]:
# Code 4.9: per-column summary statistics (count, mean, std, min, quartiles, max).
df.describe()
[13]:
height weight age male
count 544.000000 544.000000 544.000000 544.000000
mean 138.263596 35.610618 29.344393 0.472426
std 27.602448 14.719178 20.746888 0.499699
min 53.975000 4.252425 0.000000 0.000000
25% 125.095000 22.007717 12.000000 0.000000
50% 148.590000 40.057844 27.000000 0.000000
75% 157.480000 47.209005 43.000000 1.000000
max 179.070000 62.992589 88.000000 1.000000

Code 4.10

[14]:
# Code 4.10: the height column on its own.
df.height
[14]:
0      151.765
1      139.700
2      136.525
3      156.845
4      145.415
        ...
539    145.415
540    162.560
541    156.210
542     71.120
543    158.750
Name: height, Length: 544, dtype: float64

Code 4.11

[15]:
# Code 4.11: keep adults only (age 18 and over).
df2 = df.loc[df["age"].ge(18)]

Code 4.12

[16]:
# Code 4.12: prior density for the mean height, mu ~ Normal(178, 20).
x = jnp.linspace(100, 250)
mu_prior_density = stats.norm(loc=178, scale=20).pdf(x)
pd.DataFrame(mu_prior_density, index=x).plot()

Code 4.13

[17]:
# Code 4.13: prior density for the standard deviation, sigma ~ Uniform(0, 50).
x = jnp.linspace(-10, 60)
sigma_prior_density = stats.uniform(loc=0, scale=50).pdf(x)
pd.DataFrame(sigma_prior_density, index=x).plot()

Code 4.14

[18]:
# Code 4.14: prior predictive simulation of adult heights.
# Each draw uses its own subkey and `jrng` only moves forward. The previous
# pattern sampled with a key and then split that same key, which uses it twice.
jrng, key_mu, key_sigma, key_height = jax.random.split(jrng, 4)
sample_mu = dist.Normal(loc=178, scale=20).sample(key_mu, (10_000,))
sample_sigma = dist.Uniform(low=0, high=50).sample(key_sigma, (10_000,))
prior_predictive = dist.Normal(loc=sample_mu, scale=sample_sigma).sample(key_height)
az.plot_density({"Prior Predictive Distribution": prior_predictive}, hdi_prob=1)
[18]:
array([[<Axes: title={'center': 'Prior Predictive Distribution'}>]],
      dtype=object)
../_images/notebooks_04_geocentric_models_32_1.png
[19]:
def adult_height_model(priors, heights):
    """Gaussian model of adult height with priors on its mean and sd.

    Args:
        priors: mapping with ``mu_mean`` and ``mu_scale`` (Normal prior on ``mu``),
            plus ``sigma_low`` and ``sigma_high`` (Uniform prior on ``sigma``).
        heights: observed heights, or ``None`` to simulate them.
    """
    mu = numpyro.sample(
        "mu", dist.Normal(loc=priors["mu_mean"], scale=priors["mu_scale"])
    )
    sigma = numpyro.sample(
        "sigma", dist.Uniform(low=priors["sigma_low"], high=priors["sigma_high"])
    )
    numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=heights)


# The same prior predictive simulation, expressed as a NumPyro model. A fresh
# subkey prevents this draw from sharing randomness with Code 4.15's.
jrng, key = jax.random.split(jrng)
prior_samples = numpyro.infer.Predictive(adult_height_model, num_samples=10_000)(
    key,
    priors={"mu_mean": 178, "mu_scale": 20, "sigma_low": 0, "sigma_high": 50},
    heights=None,
)
az.plot_density(prior_samples, hdi_prob=1)
[19]:
array([[<Axes: title={'center': 'height'}>,
        <Axes: title={'center': 'mu'}>,
        <Axes: title={'center': 'sigma'}>]], dtype=object)
../_images/notebooks_04_geocentric_models_33_1.png

Code 4.15

[20]:
# Code 4.15: the same model with a much flatter prior, mu ~ Normal(178, 100).
# A fresh subkey matters here: with the previous cell's key, the mu draws would
# be an exact rescaling of the draws from the Normal(178, 20) prior.
jrng, key = jax.random.split(jrng)
prior_samples = numpyro.infer.Predictive(adult_height_model, num_samples=10_000)(
    key,
    priors={"mu_mean": 178, "mu_scale": 100, "sigma_low": 0, "sigma_high": 50},
    heights=None,
)
az.plot_density(prior_samples, hdi_prob=1)
[20]:
array([[<Axes: title={'center': 'height'}>,
        <Axes: title={'center': 'mu'}>,
        <Axes: title={'center': 'sigma'}>]], dtype=object)
../_images/notebooks_04_geocentric_models_35_1.png

Code 4.16

[21]:
# Code 4.16: grid approximation of the joint posterior of (mu, sigma) for the
# adult heights. Every (mu, sigma) grid point is flattened into one long vector.
mu_list = jnp.linspace(start=150, stop=160, num=100)
sigma_list = jnp.linspace(start=7, stop=9, num=100)
mesh = jnp.meshgrid(mu_list, sigma_list)
mu_flat, sigma_flat = (axis_values.reshape(-1) for axis_values in mesh)

# Total log-likelihood of all adult heights, vectorized over grid points.
log_likelihood = jax.vmap(
    lambda m, s: jnp.sum(dist.Normal(m, s).log_prob(df2.height.values))
)

posterior = {"mu": mu_flat, "sigma": sigma_flat}
posterior["LL"] = log_likelihood(posterior["mu"], posterior["sigma"])
logprob_mu = dist.Normal(178, 20).log_prob(posterior["mu"])
logprob_sigma = dist.Uniform(0, 50).log_prob(posterior["sigma"])
log_unnormalized = posterior["LL"] + logprob_mu + logprob_sigma
# Subtract the maximum before exponentiating so the values do not underflow to 0.
posterior["prob"] = jnp.exp(log_unnormalized - jnp.max(log_unnormalized))

Code 4.17

[22]:
# Code 4.17: contour plot of the grid posterior, with the flattened grid
# vectors reshaped back to the 100 x 100 meshgrid layout.
plt.contour(
    *(posterior[name].reshape(100, 100) for name in ("mu", "sigma", "prob"))
)
plt.show()
../_images/notebooks_04_geocentric_models_39_0.png

Code 4.18

[23]:
# Code 4.18: the same posterior as a heat map. `extent` maps the image onto the
# grid bounds from Code 4.16 (mu in [150, 160], sigma in [7, 9]).
heatmap_options = dict(origin="lower", extent=(150, 160, 7, 9), aspect="auto")
plt.imshow(posterior["prob"].reshape(100, 100), **heatmap_options)
plt.show()
../_images/notebooks_04_geocentric_models_41_0.png

Code 4.19

[24]:
# Code 4.19: sample grid rows in proportion to their posterior probability,
# then read off the matching (mu, sigma) values. Uses a fresh subkey instead of
# the root `jrng`, which other cells also consume.
prob = posterior["prob"] / jnp.sum(posterior["prob"])
jrng, key = jax.random.split(jrng)
sample_rows = dist.Categorical(probs=prob).sample(key, (int(1e4),))
sample_mu = posterior["mu"][sample_rows]
sample_sigma = posterior["sigma"][sample_rows]
[25]:
# Scatter of the posterior samples. The matplotlib backend is used so that the
# low alpha shows where samples pile up.
posterior_draws = pd.DataFrame({"mu": sample_mu, "sigma": sample_sigma})
posterior_draws.plot(
    kind="scatter", x="mu", y="sigma", backend="matplotlib", alpha=0.1
)
[25]:
<Axes: xlabel='mu', ylabel='sigma'>
../_images/notebooks_04_geocentric_models_44_1.png

Code 4.20–4.21

[26]:
# Code 4.20: marginal posterior density of mu.
az.plot_kde(values=sample_mu)
[26]:
<Axes: >
../_images/notebooks_04_geocentric_models_46_1.png
[27]:
# Code 4.21: marginal posterior density of sigma.
az.plot_kde(values=sample_sigma)
[27]:
<Axes: >
../_images/notebooks_04_geocentric_models_47_1.png

Code 4.22

[28]:
# Code 4.22: 89% highest posterior density intervals of the grid samples.
for param, param_draws in {"mu": sample_mu, "sigma": sample_sigma}.items():
    print(f"{param} 89% HPDI: {numpyro.diagnostics.hpdi(param_draws, prob=0.89)}")
mu 89% HPDI: [154.0404  155.35353]
sigma 89% HPDI: [7.2828283 8.212121 ]

Code 4.23

[29]:
# Code 4.23: a seeded random subsample of 20 adult heights.
df3 = df2.height.sample(n=20, random_state=seed)

Code 4.24

[30]:
# Code 4.24: repeat the grid approximation over a wider grid, using only the
# 20 heights in `df3`, then sample from it and plot the joint draws.
mu_list = jnp.linspace(start=100, stop=170, num=200)
sigma_list = jnp.linspace(start=4, stop=20, num=200)
mesh = jnp.meshgrid(mu_list, sigma_list)
posterior2 = {"mu": mesh[0].reshape(-1), "sigma": mesh[1].reshape(-1)}
posterior2["LL"] = jax.vmap(
    lambda mu, sigma: jnp.sum(dist.Normal(mu, sigma).log_prob(df3.values))
)(posterior2["mu"], posterior2["sigma"])
logprob_mu = dist.Normal(178, 20).log_prob(posterior2["mu"])
logprob_sigma = dist.Uniform(0, 50).log_prob(posterior2["sigma"])
posterior2["prob"] = posterior2["LL"] + logprob_mu + logprob_sigma
# Subtract the maximum before exponentiating so the values do not underflow to 0.
posterior2["prob"] = jnp.exp(posterior2["prob"] - jnp.max(posterior2["prob"]))
prob = posterior2["prob"] / jnp.sum(posterior2["prob"])
# Fresh subkey: the root `jrng` is also consumed by other cells' draws.
jrng, key = jax.random.split(jrng)
sample2_rows = dist.Categorical(probs=prob).sample(key, (int(1e4),))
sample2_mu = posterior2["mu"][sample2_rows]
sample2_sigma = posterior2["sigma"][sample2_rows]
plt.scatter(sample2_mu, sample2_sigma, s=64, alpha=0.1, edgecolor="none")
plt.show()
../_images/notebooks_04_geocentric_models_53_0.png

Code 4.25

[31]:
# Code 4.25: marginal posterior of mu from the 20-height sample, compared with
# a Normal of matching mean and sd (dashed).
az.plot_kde(sample2_mu)
mu_sorted = jnp.sort(sample2_mu)
mu_normal = dist.Normal(jnp.mean(mu_sorted), jnp.std(mu_sorted))
plt.plot(mu_sorted, jnp.exp(mu_normal.log_prob(mu_sorted)), "--")
plt.show()
../_images/notebooks_04_geocentric_models_55_0.png
[32]:
# Marginal posterior of sigma from the 20-height sample, compared with a
# Normal of matching mean and sd (dashed).
az.plot_kde(sample2_sigma)
sigma_sorted = jnp.sort(sample2_sigma)
sigma_normal = dist.Normal(jnp.mean(sigma_sorted), jnp.std(sigma_sorted))
plt.plot(sigma_sorted, jnp.exp(sigma_normal.log_prob(sigma_sorted)), "--")
plt.show()
../_images/notebooks_04_geocentric_models_56_0.png

Code 4.26

[33]:
# Code 4.26: reload the data and keep adults only (age 18 and over).
Howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
df = Howell1
df2 = df.loc[df["age"].ge(18)]

Code 4.27

[34]:
def adult_height_model(height, priors):
    """Gaussian model of adult height (Code 4.27).

    ``priors`` supplies ``mu_mean``/``mu_scale`` for the Normal prior on the
    mean and ``sigma_low``/``sigma_high`` for the Uniform prior on the sd.
    ``height`` holds the observations (``None`` to simulate).
    """
    mu_prior = dist.Normal(loc=priors["mu_mean"], scale=priors["mu_scale"])
    sigma_prior = dist.Uniform(low=priors["sigma_low"], high=priors["sigma_high"])
    mu = numpyro.sample("mu", mu_prior)
    sigma = numpyro.sample("sigma", sigma_prior)
    numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=height)

Code 4.28

[35]:
# Code 4.28: quadratic (Laplace) approximation of the posterior, fitted by SVI.
adult_height_priors = {
    "mu_mean": 178,
    "mu_scale": 20,
    "sigma_low": 0,
    "sigma_high": 50,
}
adult_height_laplace_model = AutoLaplaceApproximation(adult_height_model)
adult_height_svi = SVI(
    model=adult_height_model,
    guide=adult_height_laplace_model,
    optim=numpyro.optim.Adam(step_size=0.1),
    loss=Trace_ELBO(),
    height=df2.height.values,
    priors=adult_height_priors,
).run(jrng, 5_000)
# NOTE: `auto_loc` lives in NumPyro's unconstrained space. Its second entry is
# negative, which cannot be a draw from sigma's Uniform(0, 50) prior, so it is
# sigma on a transformed scale rather than the MAP of sigma itself.
adult_height_svi.params
100%|██████████████| 5000/5000 [00:01<00:00, 3540.63it/s, init loss: 28358.6328, avg. loss [4751-5000]: 1226.0383]
[35]:
{'auto_loc': Array([154.60709  ,  -1.6973906], dtype=float32)}

Code 4.29

[36]:
# Code 4.29: summarize 1,000 draws from the Laplace approximation
# (89% intervals; a single chain, so no chain grouping).
_, _jrng = jax.random.split(jrng)
laplace_params = adult_height_svi.params
samples = adult_height_laplace_model.sample_posterior(
    _jrng, laplace_params, sample_shape=(1000,)
)
numpyro.diagnostics.print_summary(samples, prob=0.89, group_by_chain=False)

                mean       std    median      5.5%     94.5%     n_eff     r_hat
        mu    154.61      0.40    154.60    153.98    155.25    977.98      1.00
     sigma      7.74      0.30      7.73      7.31      8.23   1044.95      1.00

Code 4.30

[37]:
# Code 4.30: start the optimizer at the sample mean and sd of the heights,
# with a smaller Adam step size than in Code 4.28.
init_fn_values = {"mu": df2["height"].mean(), "sigma": df2["height"].std()}
start_at_sample_moments = numpyro.infer.init_to_value(values=init_fn_values)
adult_height_laplace_model = AutoLaplaceApproximation(
    adult_height_model, init_loc_fn=start_at_sample_moments
)
fit_kwargs = {
    "height": df2.height.values,
    "priors": {"mu_mean": 178, "mu_scale": 20, "sigma_low": 0, "sigma_high": 50},
}
adult_height_svi = SVI(
    model=adult_height_model,
    guide=adult_height_laplace_model,
    optim=numpyro.optim.Adam(step_size=0.01),
    loss=Trace_ELBO(),
    **fit_kwargs,
).run(jrng, 5000)
adult_height_svi.params
100%|███████████████| 5000/5000 [00:01<00:00, 3849.34it/s, init loss: 1226.0387, avg. loss [4751-5000]: 1226.0383]
[37]:
{'auto_loc': Array([154.60728  ,  -1.6970557], dtype=float32)}
[38]:
# Summarize 1,000 draws from the refit Laplace approximation.
_, _jrng = jax.random.split(_jrng)
laplace_params = adult_height_svi.params
samples = adult_height_laplace_model.sample_posterior(
    _jrng, laplace_params, sample_shape=(1000,)
)
numpyro.diagnostics.print_summary(samples, prob=0.89, group_by_chain=False)

                mean       std    median      5.5%     94.5%     n_eff     r_hat
        mu    154.59      0.40    154.59    153.99    155.26    965.98      1.00
     sigma      7.76      0.29      7.76      7.28      8.20    892.76      1.00

Code 4.31

[39]:
# Code 4.31: a very narrow prior on mu (sd 0.1) keeps mu near 178, so the
# model has to inflate sigma to fit the data.
narrow_mu_priors = {
    "mu_mean": 178,
    "mu_scale": 0.1,
    "sigma_low": 0,
    "sigma_high": 50,
}
adult_height_laplace_model_2 = AutoLaplaceApproximation(adult_height_model)
adult_height_svi_2 = SVI(
    model=adult_height_model,
    guide=adult_height_laplace_model_2,
    optim=numpyro.optim.Adam(step_size=1),
    loss=Trace_ELBO(),
    height=df2.height.values,
    priors=narrow_mu_priors,
).run(jrng, 2000)
adult_height_svi_2.params
100%|████████████| 2000/2000 [00:00<00:00, 2366.00it/s, init loss: 1619491.5000, avg. loss [1901-2000]: 1626.5827]
[39]:
{'auto_loc': Array([ 1.7786377e+02, -3.8493071e-02], dtype=float32)}
[40]:
# Summarize 1,000 draws from the narrow-prior fit.
_, _jrng = jax.random.split(_jrng)
narrow_params = adult_height_svi_2.params
samples = adult_height_laplace_model_2.sample_posterior(
    _jrng, narrow_params, sample_shape=(1000,)
)
numpyro.diagnostics.print_summary(samples, prob=0.89, group_by_chain=False)

                mean       std    median      5.5%     94.5%     n_eff     r_hat
        mu    177.86      0.10    177.86    177.69    178.02    909.28      1.00
     sigma     24.51      0.94     24.48     23.16     26.11    945.10      1.00

Code 4.32

[41]:
# Code 4.32: variance-covariance matrix of (mu, sigma), estimated from draws of
# the Laplace approximation. Building the frame from the sample dict labels
# each column by its own site name, rather than assuming the dict iterates in
# ["mu", "sigma"] order.
_, _jrng = jax.random.split(_jrng)
samples = adult_height_laplace_model.sample_posterior(
    _jrng, adult_height_svi.params, sample_shape=(1000,)
)
vcov = pd.DataFrame(samples)[["mu", "sigma"]].cov()
vcov
[41]:
mu sigma
mu 0.162994 -0.002029
sigma -0.002029 0.086252

Code 4.33

[42]:
# Code 4.33: marginal standard deviations, then the correlation matrix (each
# covariance divided by the product of the two standard deviations).
variances = jnp.diagonal(vcov.values)
print(variances**0.5)
print(vcov.values / jnp.sqrt(jnp.outer(variances, variances)))
[0.40372518 0.2936876 ]
[[ 1.        -0.0171095]
 [-0.0171095  1.       ]]

Code 4.34

[43]:
# Code 4.34: 10,000 posterior draws of (mu, sigma) as a data frame.
_, _jrng = jax.random.split(_jrng)
laplace_draws = adult_height_laplace_model.sample_posterior(
    _jrng, adult_height_svi.params, sample_shape=(10_000,)
)
samples = pd.DataFrame(laplace_draws)
samples.head()
[43]:
mu sigma
0 154.254120 8.097201
1 154.200668 7.438225
2 153.919067 7.663988
3 154.778809 7.601790
4 154.752045 7.663299

Code 4.35

[44]:
# Code 4.35: summary statistics of the 10,000 posterior draws.
samples.describe()
[44]:
mu sigma
count 10000.000000 10000.000000
mean 154.605774 7.747313
std 0.413330 0.290448
min 152.837646 6.659805
25% 154.328228 7.547918
50% 154.603020 7.743933
75% 154.886684 7.937230
max 156.360703 8.950572

Code 4.36

[45]:
# Code 4.36: draw directly from a multivariate Normal with the posterior mean
# and covariance, mirroring rethinking's `extract.samples` (mvrnorm).
#
# Why not use `adult_height_svi.params` directly? `auto_loc` is the mode in
# NumPyro's *unconstrained* space. sigma has a Uniform(0, 50) prior, so it is
# stored as logit(sigma / 50): 50 * sigmoid(-1.697) is about 7.74, which matches
# the sigma estimates above. Estimating the mean and covariance from
# constrained-space draws sidesteps that transform.
#
# Each random call gets its own subkey. Reusing the key from Code 4.34 made
# these draws nearly identical to the ones shown there.
_jrng, key_laplace, key_mvn = jax.random.split(_jrng, 3)
laplace_draws = pd.DataFrame(
    adult_height_laplace_model.sample_posterior(
        key_laplace, adult_height_svi.params, sample_shape=(10_000,)
    )
)[["mu", "sigma"]]
vcov = laplace_draws.cov()
samples = pd.DataFrame(
    dist.MultivariateNormal(
        loc=laplace_draws.mean().values, covariance_matrix=vcov.values
    ).sample(key_mvn, sample_shape=(10_000,)),
    columns=["mu", "sigma"],
)
samples
[45]:
mu sigma
0 154.251968 8.095943
1 154.198425 7.442693
2 153.916321 7.673726
3 154.777603 7.605191
4 154.750793 7.667237
... ... ...
9995 154.689835 7.479685
9996 154.250900 7.803380
9997 154.720367 7.816793
9998 154.711197 8.014578
9999 154.432938 7.873176

10000 rows × 2 columns

Code 4.37

[46]:
# Code 4.37: adult height against weight (static matplotlib scatter).
df2.plot(backend="matplotlib", kind="scatter", y="height", x="weight")
[46]:
<Axes: xlabel='weight', ylabel='height'>
../_images/notebooks_04_geocentric_models_82_1.png

Code 4.38

[47]:
# Code 4.38: 100 prior draws of the intercept `a` and slope `b`. Each draw gets
# its own freshly split subkey; the root `jrng` is not reused for either.
_jrng, key_a, key_b = jax.random.split(_jrng, 3)
a = dist.Normal(loc=178, scale=20).sample(key_a, sample_shape=(100,))
b = dist.Normal(loc=0, scale=10).sample(key_b, sample_shape=(100,))
[48]:
def adult_height_model(
    height, weight, *, average_weight, alpha_prior, beta_prior, sigma_prior
):
    """Linear regression of height on centred weight, with a Normal slope prior.

    ``alpha_prior`` and ``beta_prior`` are keyword dicts for ``dist.Normal``, and
    ``sigma_prior`` is a keyword dict for ``dist.Uniform``. The linear predictor
    is recorded as the deterministic site ``"forecast"``. Returns the value of
    the ``"height"`` site (sampled, or ``height`` when observed).
    """
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta = numpyro.sample("beta", dist.Normal(**beta_prior))
    sigma = numpyro.sample("sigma", dist.Uniform(**sigma_prior))
    centred_weight = weight - average_weight
    forecast = numpyro.deterministic("forecast", alpha + beta * centred_weight)
    return numpyro.sample(
        "height", dist.Normal(loc=forecast, scale=sigma), obs=height
    )
[49]:
# Prior predictive simulation of the regression with b ~ Normal(0, 10).
# A fresh subkey keeps these draws independent of other cells that also
# consume the root `jrng`.
jrng, key = jax.random.split(jrng)
prior_samples = numpyro.infer.Predictive(adult_height_model, num_samples=100)(
    key,
    height=None,
    weight=df2["weight"].values,
    average_weight=df2["weight"].mean(),
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 10},
    sigma_prior={"low": 0, "high": 50},
)

Code 4.39

[50]:
def plot_prior_lines(a, b, title="b ~ Normal(0, 10)", max_lines=100):
    """Plot prior regression lines ``a[i] + b[i] * (weight - mean weight)``.

    Lines span the observed adult weight range, with horizontal reference lines
    at height 0 (dashed) and 272 (thin solid).

    Args:
        a: sequence of intercept draws.
        b: sequence of slope draws (same length as ``a``).
        title: plot title. The default describes the Normal(0, 10) slope prior,
            so pass a matching label when plotting draws from another prior.
        max_lines: maximum number of lines to draw.
    """
    plt.subplot(
        xlim=(df2.weight.min(), df2.weight.max()),
        ylim=(-100, 400),
        xlabel="weight",
        ylabel="height",
    )
    plt.axhline(y=0, c="k", ls="--")
    plt.axhline(y=272, c="k", ls="-", lw=0.5)
    plt.title(title)
    xbar = df2.weight.mean()
    x = jnp.linspace(df2.weight.min(), df2.weight.max(), 101)
    # Never index past the available draws.
    n_lines = min(max_lines, len(a), len(b))
    for i in range(n_lines):
        plt.plot(
            x,
            a[i] + b[i] * (x - xbar),
            "k",
            alpha=0.2,
        )
    plt.show()


plot_prior_lines(a=prior_samples["alpha"], b=prior_samples["beta"])
../_images/notebooks_04_geocentric_models_88_0.png

Code 4.40

[51]:
# Code 4.40: density of a LogNormal(0, 1) prior, whose support is strictly
# positive. Uses a fresh subkey rather than the shared root `jrng`.
jrng, key = jax.random.split(jrng)
b = dist.LogNormal(loc=0, scale=1).sample(key, sample_shape=(10_000,))
az.plot_kde(b)
[51]:
<Axes: >
../_images/notebooks_04_geocentric_models_90_1.png

Code 4.41

[52]:
def adult_height_model(
    height, weight, *, average_weight, alpha_prior, beta_prior, sigma_prior
):
    """Linear regression of height on centred weight, with a LogNormal slope prior.

    ``alpha_prior`` configures ``dist.Normal``, ``beta_prior`` configures
    ``dist.LogNormal`` and ``sigma_prior`` configures ``dist.Uniform``. The
    linear predictor is recorded as the deterministic site ``"forecast"``.
    Returns the value of the ``"height"`` site.
    """
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta = numpyro.sample("beta", dist.LogNormal(**beta_prior))
    sigma = numpyro.sample("sigma", dist.Uniform(**sigma_prior))
    centred_weight = weight - average_weight
    forecast = numpyro.deterministic("forecast", alpha + beta * centred_weight)
    return numpyro.sample(
        "height", dist.Normal(loc=forecast, scale=sigma), obs=height
    )
[53]:
# Code 4.41: prior regression lines with the LogNormal(0, 1) slope prior.
# A fresh subkey keeps these draws independent of Code 4.38's (which used the
# same root key, so both cells drew identical alpha values).
jrng, key = jax.random.split(jrng)
prior_samples = numpyro.infer.Predictive(adult_height_model, num_samples=100)(
    key,
    height=None,
    weight=df2["weight"].values,
    average_weight=df2["weight"].mean(),
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
)
plot_prior_lines(a=prior_samples["alpha"], b=prior_samples["beta"])
../_images/notebooks_04_geocentric_models_93_0.png

Code 4.42

[54]:
# Code 4.42: reload the data and fit the linear height ~ weight model with a
# Laplace approximation.
Howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
df = Howell1
df2 = df[df["age"] >= 18]
average_weight = df2["weight"].mean()


def adult_height_model(
    weight, *, average_weight, alpha_prior, beta_prior, sigma_prior, height=None
):
    """Linear regression of height on centred weight (LogNormal slope prior).

    Prior arguments are keyword dicts for ``dist.Normal`` (alpha),
    ``dist.LogNormal`` (beta) and ``dist.Uniform`` (sigma). Returns the value
    of the ``"height"`` site.
    """
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta = numpyro.sample("beta", dist.LogNormal(**beta_prior))
    sigma = numpyro.sample("sigma", dist.Uniform(**sigma_prior))
    mu = alpha + beta * (weight - average_weight)
    return numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=height)


model_kwargs = dict(
    weight=df2["weight"].values,
    average_weight=average_weight,
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
    height=df2["height"].values,
)
guide = AutoLaplaceApproximation(adult_height_model)
svi = SVI(
    model=adult_height_model,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.1),
    loss=Trace_ELBO(),
    **model_kwargs,
).run(jrng, 5_000)

_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(1000,))
numpyro.diagnostics.print_summary(posterior_samples, prob=0.89, group_by_chain=False)
100%|███████████████| 5000/5000 [00:01<00:00, 3082.87it/s, init loss: 5353.8887, avg. loss [4751-5000]: 1078.9313]

                mean       std    median      5.5%     94.5%     n_eff     r_hat
     alpha    154.61      0.27    154.62    154.23    155.08   1000.98      1.00
      beta      0.91      0.04      0.91      0.84      0.97   1109.19      1.00
     sigma      5.09      0.20      5.09      4.81      5.43    773.71      1.00

Code 4.43

[55]:
# Code 4.43: the same regression with the slope on the log scale:
# log_beta ~ Normal(0, 1) and beta = exp(log_beta), which is equivalent to
# beta ~ LogNormal(0, 1). The fit is stored under separate names, so `guide`,
# `svi` and `posterior_samples` still refer to the Code 4.42 model that the
# following cells summarize.


def adult_height_model_log_beta(
    weight, *, average_weight, alpha_prior, beta_prior, sigma_prior, height=None
):
    """Linear height model with the slope sampled on the log scale.

    ``beta_prior`` parametrizes the Normal prior on ``log_beta``. ``beta`` is
    recorded as the deterministic site ``exp(log_beta)``. Returns the value of
    the ``"height"`` site.
    """
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    log_beta = numpyro.sample("log_beta", dist.Normal(**beta_prior))
    beta = numpyro.deterministic("beta", jnp.exp(log_beta))
    sigma = numpyro.sample("sigma", dist.Uniform(**sigma_prior))
    mu = alpha + beta * (weight - average_weight)
    return numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=height)


guide_log_beta = AutoLaplaceApproximation(adult_height_model_log_beta)
svi_log_beta = SVI(
    model=adult_height_model_log_beta,
    guide=guide_log_beta,
    optim=numpyro.optim.Adam(step_size=0.1),
    loss=Trace_ELBO(),
    weight=df2["weight"].values,
    average_weight=average_weight,
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
    height=df2["height"].values,
).run(jrng, 5_000)

_, _jrng = jax.random.split(_jrng)
posterior_samples_log_beta = guide_log_beta.sample_posterior(
    _jrng, svi_log_beta.params, sample_shape=(1000,)
)
numpyro.diagnostics.print_summary(posterior_samples_log_beta, 0.89, False)
100%|███████████████| 5000/5000 [00:01<00:00, 3116.18it/s, init loss: 5353.8887, avg. loss [4751-5000]: 1078.9313]

                mean       std    median      5.5%     94.5%     n_eff     r_hat
     alpha    154.60      0.27    154.61    154.20    155.07    916.04      1.00
      beta      0.91      0.04      0.91      0.84      0.97    892.87      1.00
     sigma      5.08      0.19      5.08      4.79      5.40    979.48      1.00

Code 4.44

[56]:
# Code 4.44: posterior summary of the linear model (89% intervals, one chain).
numpyro.diagnostics.print_summary(posterior_samples, prob=0.89, group_by_chain=False)

                mean       std    median      5.5%     94.5%     n_eff     r_hat
     alpha    154.60      0.27    154.61    154.20    155.07    916.04      1.00
      beta      0.91      0.04      0.91      0.84      0.97    892.87      1.00
     sigma      5.08      0.19      5.08      4.79      5.40    979.48      1.00

Code 4.45

[57]:
# Code 4.45: posterior variance-covariance matrix, rounded to 3 decimals.
posterior_table = pd.DataFrame(posterior_samples)
posterior_table.cov().round(decimals=3)
[57]:
alpha beta sigma
alpha 0.073 -0.001 -0.002
beta -0.001 0.002 0.000
sigma -0.002 0.000 0.036

Code 4.46

[58]:
# Code 4.46: the data with the posterior-mean regression line overlaid.
fig = df2[["weight", "height"]].plot(kind="scatter", x="weight", y="height")
x = jnp.linspace(df2["weight"].min() * 0.95, df2["weight"].max() * 1.05)
alpha_mean = posterior_samples["alpha"].mean()
beta_mean = posterior_samples["beta"].mean()
y = alpha_mean + beta_mean * (x - average_weight)
fig.add_trace(go.Scatter(x=x, y=y, name="posterior_mean"))

Code 4.47

[59]:
# Code 4.47: the first few joint posterior draws.
pd.DataFrame(posterior_samples).head(n=5)
[59]:
alpha beta sigma
0 154.845612 0.837958 5.169509
1 154.411041 0.881906 5.131563
2 154.510010 0.946982 5.246259
3 154.651245 0.865017 5.103053
4 154.748108 0.853965 4.906911

Code 4.48

[60]:
# Code 4.48: refit the linear model using only the first N adults.
N = 10


def adult_height_model(
    weight, *, average_weight, alpha_prior, beta_prior, sigma_prior, height=None
):
    """Linear regression of height on centred weight (LogNormal slope prior).

    Returns the value of the ``"height"`` site.
    """
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta = numpyro.sample("beta", dist.LogNormal(**beta_prior))
    sigma = numpyro.sample("sigma", dist.Uniform(**sigma_prior))
    mu = alpha + beta * (weight - average_weight)
    return numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=height)


guide = AutoLaplaceApproximation(adult_height_model)
svi = SVI(
    model=adult_height_model,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.1),
    loss=Trace_ELBO(),
    weight=df2["weight"].values[:N],
    average_weight=average_weight,
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
    height=df2["height"].values[:N],
).run(jrng, 5_000)
100%|██████████████████| 5000/5000 [00:01<00:00, 3437.77it/s, init loss: 194.4541, avg. loss [4751-5000]: 37.0885]

Code 4.49

[61]:
# Code 4.49: 20 regression lines drawn from the posterior that was fitted to
# the first 10 adults, over those 10 data points.
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(20,))
numpyro.diagnostics.print_summary(posterior_samples, 0.89, False)
fig = df2[["weight", "height"]].iloc[:10].plot(kind="scatter", x="weight", y="height")
x = jnp.linspace(df2["weight"].min() * 0.95, df2["weight"].max() * 1.05)
line_style = dict(line={"color": "black"}, opacity=0.3, showlegend=False)
for alpha_draw, beta_draw in zip(
    posterior_samples["alpha"][:20], posterior_samples["beta"][:20]
):
    y = alpha_draw + beta_draw * (x - average_weight)
    fig.add_trace(go.Scatter(x=x, y=y, **line_style))
fig

                mean       std    median      5.5%     94.5%     n_eff     r_hat
     alpha    152.06      1.38    152.26    150.57    154.32     20.59      0.95
      beta      0.96      0.15      0.92      0.76      1.21     21.22      1.13
     sigma      4.49      1.38      4.37      3.11      5.50     38.15      0.96

Code 4.50

[62]:
def adult_height_model(
    weight, *, average_weight, alpha_prior, beta_prior, sigma_prior, height=None
):
    """Linear regression of height on centred weight (LogNormal slope prior).

    The linear predictor is recorded as the deterministic site ``"mu"`` so that
    posterior draws and ``Predictive`` can return it. Returns the value of the
    ``"height"`` site.
    """
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta = numpyro.sample("beta", dist.LogNormal(**beta_prior))
    sigma = numpyro.sample("sigma", dist.Uniform(**sigma_prior))
    centred_weight = weight - average_weight
    mu = numpyro.deterministic("mu", alpha + beta * centred_weight)
    return numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=height)


# Refit on all adults, then draw 1,000 posterior samples (including "mu").
full_data_kwargs = dict(
    weight=df2["weight"].values,
    average_weight=average_weight,
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
    height=df2["height"].values,
)
guide = AutoLaplaceApproximation(adult_height_model)
svi = SVI(
    model=adult_height_model,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.1),
    loss=Trace_ELBO(),
    **full_data_kwargs,
).run(jrng, 5_000)

_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(1000,))
100%|███████████████| 5000/5000 [00:01<00:00, 3110.76it/s, init loss: 5353.8887, avg. loss [4751-5000]: 1078.9313]
[63]:
# Posterior distribution of the mean height at weight 50.
alpha_draws, beta_draws = posterior_samples["alpha"], posterior_samples["beta"]
mu_at_50 = alpha_draws + beta_draws * (50 - average_weight)

Code 4.51

[64]:
# Code 4.51: density of mu at weight 50.
az.plot_kde(values=mu_at_50, label="mu|weight=50")
[64]:
<Axes: >
../_images/notebooks_04_geocentric_models_114_1.png

Code 4.52

[65]:
# Code 4.52: 89% highest posterior density interval of mu at weight 50.
numpyro.diagnostics.hpdi(mu_at_50, prob=0.89, axis=0)
[65]:
array([158.57683, 159.64386], dtype=float32)

Code 4.53

[66]:
# Code 4.53: posterior draws of mu, with one row per draw and one column per
# training case.
mu = pd.DataFrame(posterior_samples["mu"]).rename_axis(
    index="posterior predictive sample", columns="training sample"
)
mu
[66]:
training sample 0 1 2 3 4 5 6 7 8 9 ... 342 343 344 345 346 347 348 349 350 351
posterior predictive sample
0 157.304321 146.837982 142.572968 162.118820 151.260010 171.303024 148.460266 164.369080 145.346542 163.453278 ... 153.798096 157.356644 149.533066 151.050690 150.841354 156.571671 144.770889 161.307693 163.060791 161.647842
1 156.541153 146.848007 142.898056 161.000000 150.943359 169.505722 148.350449 163.084015 145.466736 162.235870 ... 153.293945 156.589615 149.343994 150.749496 150.555634 155.862625 144.933624 160.248779 161.872375 160.563812
2 157.259293 147.320450 143.270386 161.831146 151.519608 170.552475 148.860977 163.968002 145.904175 163.098343 ... 153.929779 157.308975 149.879700 151.320831 151.122055 156.563568 145.357544 161.060883 162.725647 161.383896
3 157.621613 146.588577 142.092621 162.696808 151.250031 172.378296 148.298706 165.068909 145.016373 164.103516 ... 153.925552 157.676773 149.429581 151.029373 150.808716 156.849304 144.409561 161.841751 163.689789 162.200317
4 157.061279 146.611618 142.353378 161.868118 151.026596 171.037689 148.231308 164.114792 145.122528 163.200455 ... 153.560638 157.113525 149.302399 150.817596 150.608612 156.329803 144.547806 161.058273 162.808594 161.397888
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
995 157.155518 146.734146 142.487442 161.949341 151.137177 171.094101 148.349457 164.189941 145.249100 163.278076 ... 153.664352 157.207626 149.417648 150.928757 150.720322 156.426025 144.675934 161.141693 162.887268 161.480377
996 156.507141 146.769485 142.801392 160.986465 150.883652 169.531265 148.278824 163.080063 145.381866 162.228027 ... 153.245026 156.555832 149.276932 150.688889 150.494141 155.825516 144.846298 160.231812 161.862869 160.548279
997 157.638733 146.461838 141.907242 162.780121 151.184082 172.587845 148.194260 165.183151 144.869125 164.205170 ... 153.894470 157.694626 149.339890 150.960541 150.737000 156.856354 144.254395 161.913910 163.786041 162.277161
998 157.329163 147.508148 143.506088 161.846832 151.657532 170.464767 149.030411 163.958344 146.108658 163.099014 ... 154.039124 157.378265 150.037064 151.461105 151.264694 156.641693 145.568497 161.085709 162.730728 161.404892
999 156.900024 146.914627 142.845581 161.493301 151.133453 170.255478 148.462357 163.640152 145.491699 162.766434 ... 153.554916 156.949951 149.485870 150.933746 150.734039 156.201035 144.942505 160.719437 162.391983 161.043961

1000 rows × 352 columns

Code 4.54

[67]:
# Code 4.54: posterior draws of mu for each weight in 25..70, computed through
# the model's deterministic "mu" site.
_, _jrng = jax.random.split(_jrng)
weight_seq = jnp.arange(start=25, stop=71, step=1)
mu = numpyro.infer.Predictive(guide.model, posterior_samples, return_sites=["mu"])(
    _jrng,
    weight=weight_seq,
    average_weight=average_weight,
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
    height=None,
)["mu"]
mu = pd.DataFrame(mu, columns=pd.Index(weight_seq, name="weight"))
mu.index.name = "posterior predictive sample"
# Sanity check against a hand computation for the first draw at weight 25.
# Compare with a tolerance: the two float32 computations need not round
# identically, so exact `==` is fragile.
assert jnp.isclose(
    posterior_samples["alpha"][0] + posterior_samples["beta"][0] * (25 - average_weight),
    mu.iat[0, 0],
)
mu
[67]:
weight 25 26 27 28 29 30 31 32 33 34 ... 61 62 63 64 65 66 67 68 69 70
posterior predictive sample
0 136.236908 137.159882 138.082840 139.005814 139.928787 140.851761 141.774734 142.697708 143.620682 144.543655 ... 169.463928 170.386902 171.309860 172.232834 173.155807 174.078781 175.001755 175.924728 176.847702 177.770676
1 137.030075 137.884857 138.739655 139.594437 140.449234 141.304016 142.158798 143.013596 143.868378 144.723175 ... 167.802475 168.657272 169.512054 170.366852 171.221634 172.076431 172.931213 173.786011 174.640793 175.495590
2 137.253662 138.130112 139.006577 139.883026 140.759476 141.635941 142.512390 143.388840 144.265305 145.141754 ... 168.806061 169.682510 170.558960 171.435425 172.311874 173.188339 174.064789 174.941238 175.817703 176.694153
3 135.413483 136.386444 137.359390 138.332336 139.305283 140.278229 141.251175 142.224121 143.197067 144.170013 ... 170.439606 171.412552 172.385513 173.358459 174.331406 175.304352 176.277298 177.250244 178.223190 179.196136
4 136.027405 136.948914 137.870407 138.791916 139.713409 140.634918 141.556427 142.477921 143.399429 144.320938 ... 169.201523 170.123016 171.044525 171.966034 172.887527 173.809036 174.730530 175.652039 176.573547 177.495041
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
995 136.178604 137.097610 138.016617 138.935623 139.854630 140.773636 141.692642 142.611664 143.530670 144.449677 ... 169.262894 170.181900 171.100906 172.019913 172.938919 173.857925 174.776932 175.695953 176.614960 177.533966
996 136.906448 137.765167 138.623871 139.482590 140.341309 141.200027 142.058731 142.917450 143.776169 144.634888 ... 167.820206 168.678925 169.537628 170.396347 171.255066 172.113785 172.972504 173.831207 174.689926 175.548645
997 135.141022 136.126648 137.112289 138.097931 139.083557 140.069199 141.054825 142.040466 143.026108 144.011734 ... 170.623886 171.609528 172.595154 173.580795 174.566422 175.552063 176.537689 177.523331 178.508972 179.494598
998 137.560684 138.426758 139.292816 140.158890 141.024948 141.891022 142.757080 143.623154 144.489212 145.355286 ... 168.739059 169.605118 170.471191 171.337250 172.203323 173.069382 173.935455 174.801514 175.667587 176.533646
999 136.800659 137.681229 138.561783 139.442352 140.322906 141.203476 142.084030 142.964600 143.845154 144.725723 ... 168.500885 169.381439 170.262009 171.142563 172.023132 172.903687 173.784256 174.664810 175.545380 176.425934

1000 rows × 46 columns

Code 4.55

[68]:
# Code 4.55: the first 100 posterior draws of mu at each weight. The data are
# drawn invisibly (alpha=0) only to set the axes. Use the matplotlib backend:
# with the notebook's plotly default, the first call built a separate plotly
# figure that was never shown, while the points went to a matplotlib figure.
ax = df2[["weight", "height"]].plot(
    kind="scatter", x="weight", y="height", backend="matplotlib", alpha=0
)
for i in range(100):
    ax.plot(weight_seq, mu.values[i], "o", c="royalblue", alpha=0.1)
../_images/notebooks_04_geocentric_models_122_0.png

Code 4.56

[69]:
# Code 4.56: posterior mean of mu at each weight, as a one-row frame.
column_means = mu.mean(axis=0)
mu_mean = column_means.to_frame().T
mu_mean
[69]:
weight 25 26 27 28 29 30 31 32 33 34 ... 61 62 63 64 65 66 67 68 69 70
0 136.505142 137.410141 138.315186 139.220139 140.125198 141.030197 141.93512 142.84024 143.745255 144.650284 ... 169.085724 169.990814 170.895859 171.800781 172.705841 173.610886 174.515793 175.420837 176.325867 177.23085

1 rows × 46 columns

[70]:
# 89% HPDI of mu at each weight (row 0 = lower bound, row 1 = upper bound).
# Pass a plain NumPy array: handing `hpdi` a pandas Series triggers pandas'
# "Series.swapaxes is deprecated" FutureWarning.
mu_hpdi = mu.apply(lambda col: numpyro.diagnostics.hpdi(col.to_numpy(), prob=0.89))
mu_hpdi
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:

'Series.swapaxes' is deprecated and will be removed in a future version. Please use 'Series.transpose' instead.

[70]:
weight 25 26 27 28 29 30 31 32 33 34 ... 61 62 63 64 65 66 67 68 69 70
0 135.026993 136.026367 136.997055 137.951721 138.926727 139.908661 140.870804 141.861160 142.846619 143.793472 ... 167.961838 168.830612 169.670883 170.516129 171.353882 172.191635 173.028885 173.860092 174.689926 175.517319
1 137.909683 138.774658 139.614487 140.431335 141.287354 142.156265 142.993195 143.865662 144.713287 145.557938 ... 170.256363 171.241119 172.204147 173.178024 174.145081 175.113068 176.089157 177.051392 178.017044 178.985031

2 rows × 46 columns

Code 4.57

[71]:
# Code 4.57: the data, the posterior-mean line of mu, and its shaded 89% HPDI.
ax = df2[["weight", "height"]].plot(
    kind="scatter", x="weight", y="height", backend="matplotlib", alpha=0.5
)
hpdi_lower, hpdi_upper = mu_hpdi.iloc[0], mu_hpdi.iloc[1]
plt.plot(weight_seq, mu_mean.T, "k-")
plt.fill_between(weight_seq, hpdi_lower, hpdi_upper, color="k", alpha=0.4)
[71]:
<matplotlib.collections.PolyCollection at 0x74c8e47cbcb0>
../_images/notebooks_04_geocentric_models_127_1.png

Code 4.58

[72]:
# Code 4.58: compute mu "by hand" (a link function) from fresh posterior draws
# of alpha and beta. Split `_jrng` first, because Code 4.54 already used it.
_jrng, key = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(key, svi.params, sample_shape=(1000,))
# Drop the per-training-case deterministic "mu"; it is recomputed below.
posterior_samples.pop("mu")
posterior_samples = pd.DataFrame(posterior_samples)


def mu_link(weight):
    """Posterior draws of mu at ``weight``, one per row of ``posterior_samples``.

    Reads the module-level ``posterior_samples`` and ``average_weight``. It is
    called below with scalar weights.
    """
    return posterior_samples["alpha"] + posterior_samples["beta"] * (
        weight - average_weight
    )


mu = pd.concat([mu_link(_weight) for _weight in weight_seq], axis=1)
mu.columns = pd.Index(weight_seq, name="weight")
mu.index.name = "posterior predictive sample"
mu
[72]:
weight 25 26 27 28 29 30 31 32 33 34 ... 61 62 63 64 65 66 67 68 69 70
posterior predictive sample
0 135.985901 136.901230 137.816559 138.731888 139.647217 140.562546 141.477875 142.393219 143.308548 144.223877 ... 168.937790 169.853119 170.768448 171.683777 172.599121 173.514450 174.429779 175.345108 176.260437 177.175766
1 137.509995 138.359192 139.208389 140.057587 140.906784 141.755981 142.605164 143.454361 144.303558 145.152756 ... 168.081024 168.930222 169.779419 170.628616 171.477814 172.327011 173.176208 174.025391 174.874588 175.723785
2 135.088242 136.062073 137.035889 138.009720 138.983551 139.957382 140.931213 141.905029 142.878860 143.852692 ... 170.146027 171.119858 172.093689 173.067520 174.041336 175.015167 175.988998 176.962830 177.936646 178.910477
3 134.468567 135.461731 136.454910 137.448090 138.441254 139.434433 140.427612 141.420792 142.413956 143.407135 ... 170.222870 171.216049 172.209229 173.202393 174.195572 175.188751 176.181915 177.175095 178.168274 179.161453
4 136.359711 137.267639 138.175568 139.083481 139.991409 140.899338 141.807251 142.715179 143.623108 144.531036 ... 169.044968 169.952896 170.860809 171.768738 172.676666 173.584595 174.492508 175.400436 176.308365 177.216278
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
995 136.797119 137.712357 138.627609 139.542847 140.458099 141.373337 142.288589 143.203827 144.119080 145.034332 ... 169.745956 170.661209 171.576462 172.491699 173.406952 174.322189 175.237442 176.152679 177.067932 177.983170
996 135.942429 136.874466 137.806503 138.738525 139.670563 140.602600 141.534637 142.466660 143.398697 144.330734 ... 169.495636 170.427673 171.359711 172.291733 173.223770 174.155807 175.087845 176.019867 176.951904 177.883942
997 136.452545 137.376831 138.301117 139.225418 140.149719 141.074005 141.998291 142.922592 143.846878 144.771179 ... 169.727097 170.651398 171.575684 172.499985 173.424271 174.348572 175.272858 176.197144 177.121445 178.045746
998 136.993576 137.869659 138.745728 139.621811 140.497894 141.373978 142.250061 143.126144 144.002228 144.878296 ... 168.532501 169.408585 170.284668 171.160751 172.036835 172.912918 173.789001 174.665070 175.541153 176.417236
999 135.966492 136.885849 137.805206 138.724579 139.643936 140.563309 141.482666 142.402039 143.321396 144.240753 ... 169.063583 169.982941 170.902313 171.821671 172.741043 173.660400 174.579773 175.499130 176.418488 177.337860

1000 rows × 46 columns

Code 4.59

[73]:
# Code 4.59: simulate heights (not just mu) at every weight in 25..70.
# Separate PRNG keys are used for the posterior draws and for the simulated
# observations: a JAX key must not be consumed by two independent draws.
_jrng, _posterior_key, _predictive_key = jax.random.split(_jrng, 3)
posterior_samples = guide.sample_posterior(
    _posterior_key, svi.params, sample_shape=(1000,)
)
weight_seq = jnp.arange(start=25, stop=71, step=1)
height = numpyro.infer.Predictive(
    guide.model, posterior_samples, return_sites=["height"]
)(
    _predictive_key,
    weight=weight_seq,
    average_weight=average_weight,
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
    height=None,
)["height"]
height = pd.DataFrame(height, columns=pd.Index(weight_seq, name="weight"))
height.index.name = "posterior predictive height sample"
height
[73]:
weight 25 26 27 28 29 30 31 32 33 34 ... 61 62 63 64 65 66 67 68 69 70
posterior predictive height sample
0 136.508286 149.435226 135.722305 149.388138 139.763763 139.949677 140.646118 138.855057 142.752548 143.562958 ... 169.097351 173.265259 164.546005 167.680725 166.906647 164.859055 172.325638 177.131287 174.472702 171.804504
1 137.130280 131.644226 131.078796 146.600143 134.685349 142.405029 145.415344 141.809189 143.663727 150.538666 ... 171.158493 173.759567 166.737335 173.044586 169.006485 182.760620 178.568680 176.966095 173.372452 167.617493
2 140.201584 140.972473 138.505524 138.571991 145.531540 137.499023 151.230240 136.424301 141.453522 144.187225 ... 165.414032 174.831863 168.857559 164.795746 169.127579 177.319336 182.095734 170.876938 182.959732 178.725433
3 137.223373 138.317673 135.861816 141.323700 144.002625 141.397507 146.272064 141.174606 144.210037 150.681000 ... 165.021179 174.404190 171.814453 173.515961 168.846237 175.028915 176.868759 169.214584 163.540894 176.363586
4 134.241409 144.002289 140.749893 137.863724 141.580338 139.624512 135.989700 139.275009 140.649887 138.531845 ... 167.624329 169.881714 172.509033 160.924622 168.618073 169.110428 183.953873 183.501358 181.296814 176.981949
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
995 141.023254 140.618423 148.049911 143.028687 135.105743 136.967056 142.688721 145.131409 145.675903 143.714188 ... 161.923935 166.959091 167.248596 160.272964 170.994476 170.879868 175.324860 184.061798 176.259384 177.054688
996 136.336578 138.168930 135.906570 144.396240 139.384766 147.582108 133.617355 143.931732 146.498352 142.418411 ... 159.824448 166.302322 163.272964 166.170166 166.100311 174.871994 173.713211 170.594254 167.849503 179.801788
997 128.789352 137.093933 141.240921 131.618515 128.948471 133.299667 146.318787 131.432816 147.747055 143.982483 ... 172.134598 169.760605 170.733414 175.537643 174.765945 180.729172 179.618698 169.870499 170.286469 179.472046
998 126.851830 134.977737 135.001389 131.294464 147.969147 137.543365 146.132797 140.221573 146.279465 139.622467 ... 166.252594 165.666428 172.711197 175.712952 176.767441 178.774521 177.582687 180.221420 176.854004 167.835709
999 135.678772 135.806549 145.103104 138.308670 141.327972 148.940582 133.575974 145.009354 141.294922 147.490875 ... 167.654053 172.131699 167.774033 174.673126 173.745056 175.865540 175.854340 176.483749 174.906326 168.637466

1000 rows × 46 columns

Code 4.60

[74]:
# 89% HPDI of the simulated heights at every weight (row 0 lower, row 1 upper).
# hpdi is given a NumPy array; with a pandas Series it calls the deprecated
# Series.swapaxes (FutureWarning: "will be removed in a future version").
height_hpdi = height.apply(
    lambda x: numpyro.diagnostics.hpdi(x.to_numpy(), prob=0.89)
)
height_hpdi
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:

'Series.swapaxes' is deprecated and will be removed in a future version. Please use 'Series.transpose' instead.

[74]:
weight 25 26 27 28 29 30 31 32 33 34 ... 61 62 63 64 65 66 67 68 69 70
0 129.079712 128.893143 129.459946 130.970932 131.758331 132.952377 133.561066 134.395966 136.13765 137.199585 ... 160.902451 161.858826 163.364090 162.511414 165.133896 166.108643 166.287201 167.202789 168.232437 168.767731
1 144.476105 145.375854 145.479507 147.659012 148.286728 149.444717 149.126480 150.059341 151.34671 152.877716 ... 177.465652 178.126465 179.430649 178.957062 181.207123 182.581741 182.563736 182.397293 184.117691 185.513077

2 rows × 46 columns

Code 4.61

[75]:
# Code 4.61: data, mean of mu, 89% HPDI of mu (darker band) and 89% HPDI of
# simulated heights (lighter band).
ax = df2[["weight", "height"]].plot(
    kind="scatter", x="weight", y="height", alpha=0.5, backend="matplotlib"
)
plt.plot(weight_seq, mu_mean.T, "k-")
for band, shade in ((mu_hpdi, 0.4), (height_hpdi, 0.2)):
    band_artist = plt.fill_between(
        weight_seq, band.iloc[0], band.iloc[1], color="k", alpha=shade
    )
band_artist
[75]:
<matplotlib.collections.PolyCollection at 0x74c8c81bb5f0>
../_images/notebooks_04_geocentric_models_135_1.png

Code 4.62

[76]:
# Code 4.62: repeat 4.59-4.61 with 10,000 posterior draws instead of 1,000.
# Independent PRNG keys for the posterior draws and the simulated heights.
_jrng, _posterior_key, _predictive_key = jax.random.split(_jrng, 3)
posterior_samples = guide.sample_posterior(
    _posterior_key, svi.params, sample_shape=(10_000,)
)
weight_seq = jnp.arange(start=25, stop=71, step=1)
height = numpyro.infer.Predictive(
    guide.model, posterior_samples, return_sites=["height"]
)(
    _predictive_key,
    weight=weight_seq,
    average_weight=average_weight,
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
    height=None,
)["height"]
height = pd.DataFrame(height, columns=pd.Index(weight_seq, name="weight"))
height.index.name = "posterior predictive height sample"
# NumPy input avoids the deprecated Series.swapaxes call inside hpdi.
height_hpdi = height.apply(
    lambda x: numpyro.diagnostics.hpdi(x.to_numpy(), prob=0.89)
)
display(height_hpdi)
ax = df2[["weight", "height"]].plot(
    kind="scatter", x="weight", y="height", backend="matplotlib", alpha=0.5
)
plt.plot(weight_seq, mu_mean.T, "k-")
plt.fill_between(weight_seq, mu_hpdi.iloc[0], mu_hpdi.iloc[1], color="k", alpha=0.4)
plt.fill_between(
    weight_seq, height_hpdi.iloc[0], height_hpdi.iloc[1], color="k", alpha=0.2
)
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:

'Series.swapaxes' is deprecated and will be removed in a future version. Please use 'Series.transpose' instead.

weight 25 26 27 28 29 30 31 32 33 34 ... 61 62 63 64 65 66 67 68 69 70
0 128.198639 129.154419 130.445114 130.696640 131.987961 133.248535 133.680359 134.982620 135.565125 136.709518 ... 160.880524 161.950745 162.872162 163.943100 164.586792 165.942871 166.361206 167.636047 168.283356 169.236526
1 144.747787 146.053528 146.836273 147.281357 148.310165 149.428192 150.228561 151.407837 151.797745 152.947876 ... 177.590027 178.494415 179.222809 180.173004 180.973145 182.395248 182.904404 183.754227 184.861252 185.728287

2 rows × 46 columns

[76]:
<matplotlib.collections.PolyCollection at 0x74c8c7b26300>
../_images/notebooks_04_geocentric_models_137_3.png

Code 4.63

[77]:
# Code 4.63: simulate heights by hand instead of with Predictive.
_jrng, _posterior_key = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(
    _posterior_key, svi.params, sample_shape=(1000,)
)
posterior_samples.pop("mu")
posterior_samples = pd.DataFrame(posterior_samples)


def sim_height(weight, jrng):
    """Draw one simulated height per posterior sample at a single ``weight``.

    ``jrng`` is the PRNG key for this draw. Callers must pass a fresh key on
    every call; reusing a key gives every weight the same standard-normal noise.
    """
    mu = posterior_samples["alpha"] + posterior_samples["beta"] * (
        weight - average_weight
    )
    return dist.Normal(loc=mu, scale=posterior_samples["sigma"]).sample(jrng)


# One independent key per weight. Splitting a local key (rather than
# rebinding the notebook-level ``jrng``) also leaves later cells unaffected.
_jrng, *_weight_keys = jax.random.split(_jrng, len(weight_seq) + 1)
height = pd.concat(
    [sim_height(_weight, _key) for _weight, _key in zip(weight_seq, _weight_keys)],
    axis=1,
)
height.columns = pd.Index(weight_seq, name="weight")
height.index.name = "posterior predictive sample"
# NumPy input avoids the deprecated Series.swapaxes call inside hpdi.
height_hpdi = height.apply(
    lambda x: numpyro.diagnostics.hpdi(x.to_numpy(), prob=0.89)
)
display(height_hpdi)
ax = df2[["weight", "height"]].plot(
    kind="scatter", x="weight", y="height", backend="matplotlib", alpha=0.5
)
plt.plot(weight_seq, mu_mean.T, "k-")
plt.fill_between(weight_seq, mu_hpdi.iloc[0], mu_hpdi.iloc[1], color="k", alpha=0.4)
plt.fill_between(
    weight_seq, height_hpdi.iloc[0], height_hpdi.iloc[1], color="k", alpha=0.2
)
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:

'Series.swapaxes' is deprecated and will be removed in a future version. Please use 'Series.transpose' instead.

weight 25 26 27 28 29 30 31 32 33 34 ... 61 62 63 64 65 66 67 68 69 70
0 128.353043 129.231384 130.064392 130.897827 131.913391 132.745377 133.669144 134.482864 135.470245 136.378235 ... 161.073730 161.500427 162.376999 163.222519 164.476730 164.565491 165.759048 166.797089 167.679825 168.559372
1 143.963165 144.840317 145.646332 146.425797 147.388672 148.204834 149.118210 149.980774 150.996918 151.936279 ... 177.229584 177.667191 178.607346 179.467682 180.740768 180.852631 182.037155 183.107285 183.987915 184.907028

2 rows × 46 columns

[77]:
<matplotlib.collections.PolyCollection at 0x74c8c7aaf350>
../_images/notebooks_04_geocentric_models_139_3.png

Code 4.64

[78]:
# Code 4.64: reload the full Howell1 table (all 544 rows, children included).
df = pd.read_csv("../data/Howell1.csv", delimiter=";")
df
[78]:
height weight age male
0 151.765 47.825606 63.0 1
1 139.700 36.485807 63.0 0
2 136.525 31.864838 65.0 0
3 156.845 53.041914 41.0 1
4 145.415 41.276872 51.0 0
... ... ... ... ...
539 145.415 31.127751 17.0 1
540 162.560 52.163080 31.0 1
541 156.210 54.062497 21.0 0
542 71.120 8.051258 0.0 1
543 158.750 52.531624 68.0 1

544 rows × 4 columns

[79]:
# Height against weight for the full sample; the relationship is visibly curved.
df.loc[:, ["weight", "height"]].plot(
    x="weight", y="height", kind="scatter", backend="matplotlib"
)
[79]:
<Axes: xlabel='weight', ylabel='height'>
../_images/notebooks_04_geocentric_models_142_1.png

Code 4.65

[80]:
# Code 4.65: standardize weight and add its square for a parabolic model.
df["weight_s"] = (df["weight"] - df["weight"].mean()) / df["weight"].std()
df["weight_s2"] = df["weight_s"] ** 2


def m4_5(
    weight_s,
    weight_s2,
    *,
    alpha_prior={"loc": 178, "scale": 20},
    beta1_prior={"loc": 0, "scale": 1},
    beta2_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
    height=None,
):
    """Parabolic regression of height on standardized weight (m4.5).

    mu = alpha + beta1 * weight_s + beta2 * weight_s2. beta1 has a log-normal
    (positive) prior, beta2 a normal prior, sigma a uniform prior. Each
    ``*_prior`` dict is unpacked as keyword arguments of its distribution and
    is only read, never mutated, so sharing the dict defaults is safe.
    Pass ``height=None`` to simulate heights instead of conditioning on them.
    """
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta1 = numpyro.sample("beta1", dist.LogNormal(**beta1_prior))
    beta2 = numpyro.sample("beta2", dist.Normal(**beta2_prior))
    sigma = numpyro.sample("sigma", dist.Uniform(**sigma_prior))
    mu = numpyro.deterministic("mu", alpha + beta1 * weight_s + beta2 * weight_s2)
    height = numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=height)
    return height


guide = AutoLaplaceApproximation(m4_5)
# Fit with the freshly split key. Passing the notebook-level ``jrng`` here
# would leave the split key unused.
_, _jrng = jax.random.split(_jrng)
svi = SVI(
    model=m4_5,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.1),
    loss=Trace_ELBO(),
    weight_s=df["weight_s"].values,
    weight_s2=df["weight_s2"].values,
    height=df["height"].values,
).run(_jrng, 5_000)
100%|███████████████| 5000/5000 [00:01<00:00, 3204.34it/s, init loss: 8877.6973, avg. loss [4751-5000]: 1770.2781]

Code 4.66

[81]:
# Code 4.66: marginal summary of the m4_5 Laplace posterior.
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(1000,))
# Drop only the deterministic "mu" site. Compare with != : a substring test
# such as `k not in "mu"` would also drop keys like "m" or "u".
_posterior_samples = {k: v for k, v in posterior_samples.items() if k != "mu"}
numpyro.diagnostics.print_summary(_posterior_samples, prob=0.89, group_by_chain=False)

                mean       std    median      5.5%     94.5%     n_eff     r_hat
     alpha    146.05      0.38    146.06    145.49    146.68    932.87      1.00
     beta1     21.73      0.30     21.72     21.26     22.19    903.86      1.00
     beta2     -7.79      0.28     -7.79     -8.21     -7.30    940.81      1.00
     sigma      5.78      0.19      5.78      5.44      6.02   1084.68      1.00

Code 4.67

[82]:
# Code 4.67: posterior mu and simulated heights of the parabolic model over a
# grid of standardized weights. Independent PRNG keys for the posterior
# draws and the simulated heights, so no key is consumed twice.
_jrng, _posterior_key, _predictive_key = jax.random.split(_jrng, 3)
posterior_samples = guide.sample_posterior(
    _posterior_key, svi.params, sample_shape=(10_000,)
)
weight_seq = jnp.linspace(start=-2.2, stop=2, num=50)
weight_seq2 = weight_seq**2
posterior_predictive = numpyro.infer.Predictive(
    guide.model, posterior_samples, return_sites=["mu", "height"]
)(
    _predictive_key,
    weight_s=weight_seq,
    weight_s2=weight_seq2,
    height=None,
)
mu_posterior_predictive = pd.DataFrame(posterior_predictive["mu"], columns=weight_seq)
height_posterior_predictive = pd.DataFrame(
    posterior_predictive["height"], columns=weight_seq
)
[83]:
# Mean and 89% HPDI (row 0 lower, row 1 upper) of mu and of simulated height
# at each grid weight. hpdi is given NumPy arrays: with a DataFrame it calls
# the deprecated DataFrame.swapaxes ("will be removed in a future version").
mu_mean = mu_posterior_predictive.mean(axis=0).to_frame().T
display(mu_mean)

mu_hpdi = pd.DataFrame(
    numpyro.diagnostics.hpdi(mu_posterior_predictive.to_numpy(), prob=0.89),
    columns=weight_seq,
)
display(mu_hpdi)

height_mean = height_posterior_predictive.mean(axis=0).to_frame().T
display(height_mean)

height_hpdi = pd.DataFrame(
    numpyro.diagnostics.hpdi(height_posterior_predictive.to_numpy(), prob=0.89),
    columns=weight_seq,
)
display(height_hpdi)
-2.200000 -2.114286 -2.028572 -1.942857 -1.857143 -1.771429 -1.685714 -1.600000 -1.514286 -1.428572 ... 1.228571 1.314286 1.400000 1.485714 1.571429 1.657143 1.742857 1.828571 1.914286 2.000000
0 60.481586 65.229164 69.861977 74.380379 78.783966 83.07338 87.247482 91.307541 95.253059 99.083488 ... 160.993332 161.156906 161.205719 161.139893 160.959564 160.664368 160.255157 159.731079 159.092514 158.339249

1 rows × 50 columns

/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:

'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.

-2.200000 -2.114286 -2.028572 -1.942857 -1.857143 -1.771429 -1.685714 -1.600000 -1.514286 -1.428572 ... 1.228571 1.314286 1.400000 1.485714 1.571429 1.657143 1.742857 1.828571 1.914286 2.000000
0 59.013569 63.846535 68.564537 73.261055 77.718513 82.131294 86.331566 90.494766 94.491516 98.429581 ... 160.208939 160.259521 160.205215 160.036804 159.673935 159.251678 158.742950 158.031952 157.244263 156.378006
1 61.993671 66.578056 71.061684 75.535278 79.797653 84.034424 88.067047 92.095146 95.975159 99.822418 ... 161.825531 162.069397 162.229767 162.295349 162.181366 162.018936 161.797211 161.378479 160.898743 160.359406

2 rows × 50 columns

-2.200000 -2.114286 -2.028572 -1.942857 -1.857143 -1.771429 -1.685714 -1.600000 -1.514286 -1.428572 ... 1.228571 1.314286 1.400000 1.485714 1.571429 1.657143 1.742857 1.828571 1.914286 2.000000
0 60.416737 65.212372 69.902946 74.325066 78.807297 83.031853 87.275444 91.365974 95.178131 99.095703 ... 161.028168 161.211151 161.167297 161.197403 160.950775 160.647278 160.438293 159.657928 158.956055 158.412537

1 rows × 50 columns

/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:

'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.

-2.200000 -2.114286 -2.028572 -1.942857 -1.857143 -1.771429 -1.685714 -1.600000 -1.514286 -1.428572 ... 1.228571 1.314286 1.400000 1.485714 1.571429 1.657143 1.742857 1.828571 1.914286 2.000000
0 51.016739 55.879940 60.490021 65.458740 69.163567 73.893677 77.600990 82.287376 85.649536 90.373688 ... 151.999893 151.800873 152.276199 151.935120 151.718018 151.222717 150.975449 150.216843 149.342133 148.985733
1 69.423355 74.413803 79.079216 83.902122 87.854408 92.467232 96.193443 100.698784 104.460617 108.844330 ... 170.468155 170.176056 170.734787 170.366165 170.243973 169.753738 169.766205 168.980286 168.354767 167.897156

2 rows × 50 columns

Code 4.68

[84]:
# Code 4.68: parabolic fit on the standardized-weight scale -- mean of mu,
# 89% HPDI of mu (darker band) and 89% HPDI of simulated heights (lighter band).
ax = df[["weight_s", "height"]].plot(
    kind="scatter", x="weight_s", y="height", backend="matplotlib"
)
ax.plot(weight_seq, mu_mean.loc[0], "k")
ax.fill_between(weight_seq, mu_hpdi.loc[0], mu_hpdi.loc[1], color="k", alpha=0.5)
ax.fill_between(
    weight_seq, height_hpdi.loc[0], height_hpdi.loc[1], color="k", alpha=0.2
)
plt.show()
../_images/notebooks_04_geocentric_models_151_0.png

Code 4.69

[85]:
# Code 4.69: add a cubic term and fit the cubic regression m4_6.
df["weight_s3"] = df["weight_s"] ** 3


def m4_6(
    weight_s,
    weight_s2,
    weight_s3,
    *,
    alpha_prior={"loc": 178, "scale": 20},
    beta1_prior={"loc": 0, "scale": 1},
    beta2_prior={"loc": 0, "scale": 10},
    beta3_prior={"loc": 0, "scale": 10},
    sigma_prior={"low": 0, "high": 50},
    height=None,
):
    """Cubic regression of height on standardized weight (m4.6).

    mu = alpha + beta1 * w + beta2 * w**2 + beta3 * w**3 (w standardized, with
    the powers passed in precomputed). beta1 has a log-normal (positive) prior.
    Each ``*_prior`` dict is unpacked into its distribution and only read.
    Pass ``height=None`` to simulate heights instead of conditioning on them.
    """
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta1 = numpyro.sample("beta1", dist.LogNormal(**beta1_prior))
    beta2 = numpyro.sample("beta2", dist.Normal(**beta2_prior))
    beta3 = numpyro.sample("beta3", dist.Normal(**beta3_prior))
    sigma = numpyro.sample("sigma", dist.Uniform(**sigma_prior))
    mu = numpyro.deterministic(
        "mu", alpha + beta1 * weight_s + beta2 * weight_s2 + beta3 * weight_s3
    )
    height = numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=height)
    return height


guide = AutoLaplaceApproximation(m4_6)
# Fit with the freshly split key. Passing the notebook-level ``jrng`` here
# would leave the split key unused.
_, _jrng = jax.random.split(_jrng)
svi = SVI(
    model=m4_6,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.1),
    loss=Trace_ELBO(),
    weight_s=df["weight_s"].values,
    weight_s2=df["weight_s2"].values,
    weight_s3=df["weight_s3"].values,
    height=df["height"].values,
).run(_jrng, 5_000)
100%|███████████████| 5000/5000 [00:01<00:00, 3086.06it/s, init loss: 5829.7710, avg. loss [4751-5000]: 1646.5316]
[86]:
# Posterior predictions for the cubic model m4_6 over a standardized-weight grid.
# Independent PRNG keys for the posterior draws and the simulated heights.
_jrng, _posterior_key, _predictive_key = jax.random.split(_jrng, 3)
posterior_samples = guide.sample_posterior(
    _posterior_key, svi.params, sample_shape=(10_000,)
)
weight_seq = jnp.linspace(start=-2.2, stop=2, num=50)
weight_seq2 = weight_seq**2
weight_seq3 = weight_seq**3
posterior_predictive = numpyro.infer.Predictive(
    guide.model, posterior_samples, return_sites=["mu", "height"]
)(
    _predictive_key,
    weight_s=weight_seq,
    weight_s2=weight_seq2,
    weight_s3=weight_seq3,
    height=None,
)
mu_posterior_predictive = pd.DataFrame(posterior_predictive["mu"], columns=weight_seq)
height_posterior_predictive = pd.DataFrame(
    posterior_predictive["height"], columns=weight_seq
)

mu_mean = mu_posterior_predictive.mean(axis=0).to_frame().T
# hpdi is given NumPy arrays: with a DataFrame it calls the deprecated
# DataFrame.swapaxes ("will be removed in a future version").
mu_hpdi = pd.DataFrame(
    numpyro.diagnostics.hpdi(mu_posterior_predictive.to_numpy(), prob=0.89),
    columns=weight_seq,
)
height_mean = height_posterior_predictive.mean(axis=0).to_frame().T
height_hpdi = pd.DataFrame(
    numpyro.diagnostics.hpdi(height_posterior_predictive.to_numpy(), prob=0.89),
    columns=weight_seq,
)
df[["weight_s", "height"]].plot(
    kind="scatter", x="weight_s", y="height", backend="matplotlib"
)
plt.plot(weight_seq, mu_mean.loc[0, :], "k")
plt.fill_between(weight_seq, mu_hpdi.loc[0, :], mu_hpdi.loc[1, :], color="k", alpha=0.5)
plt.fill_between(
    weight_seq, height_hpdi.loc[0, :], height_hpdi.loc[1, :], color="k", alpha=0.2
)
plt.show()
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:

'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.

/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:

'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.

../_images/notebooks_04_geocentric_models_154_1.png

Code 4.70

[87]:
# Code 4.70: plot against standardized weight with the x tick labels removed.
df.plot(kind="scatter", x="weight_s", y="height", xticks=[], backend="matplotlib")
[87]:
<Axes: xlabel='weight_s', ylabel='height'>
../_images/notebooks_04_geocentric_models_156_1.png

Code 4.71

[88]:
# Code 4.71: keep plotting against standardized weight, but label the x axis in
# the original weight units by undoing the standardization at each tick.
at = np.array([-2, -1, 0, 1, 2])
labels = at * df["weight"].std() + df["weight"].mean()
ax = df.plot(kind="scatter", x="weight_s", y="height", backend="matplotlib")
ax.set_xticks(at)
ax.set_xticklabels([f"{label:.1f}" for label in labels])
ax.set_xlabel("weight")
ax
[88]:
<Axes: xlabel='weight', ylabel='height'>
../_images/notebooks_04_geocentric_models_158_1.png

Code 4.72

[89]:
# Code 4.72: cherry blossom data; the count row shows temp has missing years.
df = pd.read_csv("../data/cherry_blossoms.csv", delimiter=";")
df.describe()
[89]:
year doy temp temp_upper temp_lower
count 1215.000000 827.000000 1124.000000 1124.000000 1124.000000
mean 1408.000000 104.540508 6.141886 7.185151 5.098941
std 350.884596 6.407036 0.663648 0.992921 0.850350
min 801.000000 86.000000 4.670000 5.450000 0.750000
25% 1104.500000 100.000000 5.700000 6.480000 4.610000
50% 1408.000000 105.000000 6.100000 7.040000 5.145000
75% 1711.500000 109.000000 6.530000 7.720000 5.542500
max 2015.000000 124.000000 8.300000 12.100000 7.740000
[90]:
# Temperature series in row order (line plot).
df["temp"].plot(kind="line", backend="matplotlib")
[90]:
<Axes: >
../_images/notebooks_04_geocentric_models_161_1.png

Code 4.73

[91]:
# Code 4.73: keep rows with an observed temperature, then place the knots at
# evenly spaced quantiles of the years in *that* subset. The basis is
# evaluated on df2.year, so the knots must come from df2 as well.
df2 = df.dropna(subset=["temp"])
num_knots = 15
knots_list = jnp.quantile(
    df2["year"].values, jnp.linspace(start=0, stop=1, num=num_knots)
)

Code 4.74

[92]:
# Code 4.74: cubic B-spline basis matrix, one column per basis function.
degree = 3
# Clamp the spline by repeating each boundary knot `degree` extra times.
knots = jnp.pad(knots_list, (degree, degree), mode="edge")
# A B-spline on len(knots) knots has len(knots) - degree - 1 basis functions
# (num_knots + 2 for cubic splines); derive it so `degree` can be changed.
n_basis = len(knots) - degree - 1
B = BSpline(knots, jnp.identity(n_basis), k=degree)(df2.year.values)

Code 4.75

[93]:
# Code 4.75: each basis function plotted against year.
plt.subplot(
    xlabel="year",
    ylabel="basis value",
    xlim=(df2.year.min(), df2.year.max()),
    ylim=(0, 1),
)
for basis_function in B.T:
    plt.plot(df2.year, basis_function, "k", alpha=0.5)
../_images/notebooks_04_geocentric_models_167_0.png

Code 4.76

[94]:
def m4_7(
    B,
    *,
    alpha_prior={"loc": 6, "scale": 10},
    weight_prior={"loc": 0, "scale": 1},
    sigma_prior={"rate": 1},
    temp=None,
):
    """Spline regression of temperature (m4.7): mu = alpha + B @ weight.

    One weight is sampled per column of ``B``. Each ``*_prior`` dict is
    unpacked into its distribution. Pass ``temp=None`` to simulate instead of
    conditioning on observed temperatures.
    """
    intercept = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    basis_weights = numpyro.sample(
        "weight", dist.Normal(**weight_prior), sample_shape=B.shape[1:]
    )
    spline = jnp.dot(B, basis_weights)
    mu = numpyro.deterministic("mu", intercept + spline)
    noise_scale = numpyro.sample("sigma", dist.Exponential(**sigma_prior))
    return numpyro.sample("temp", dist.Normal(loc=mu, scale=noise_scale), obs=temp)
[95]:
# Fit m4_7 with the Laplace approximation, starting every spline weight at 0.
# The init key must be the sample-site name "weight": init_to_value only uses
# entries that name a sample site and falls back to its default for the rest,
# so a key such as "w" would have no effect.
guide = AutoLaplaceApproximation(
    m4_7,
    init_loc_fn=numpyro.infer.init_to_value(
        values={"weight": jnp.zeros(B.shape[1])}
    ),
)
# Fit with the freshly split key rather than the notebook-level ``jrng``.
_, _jrng = jax.random.split(_jrng)
svi = SVI(
    model=m4_7,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.5),
    loss=Trace_ELBO(),
    B=B,
    temp=df2["temp"].values,
).run(_jrng, 10_000)
100%|█████████████| 10000/10000 [00:02<00:00, 3872.34it/s, init loss: 5767.3096, avg. loss [9501-10000]: 483.9176]
[96]:
# Summarize the m4_7 posterior, leaving out the deterministic "mu" site.
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(10_000,))
_posterior_samples = dict(posterior_samples)
_posterior_samples.pop("mu", None)
numpyro.diagnostics.print_summary(_posterior_samples, prob=0.89, group_by_chain=False)

                mean       std    median      5.5%     94.5%     n_eff     r_hat
     alpha      6.52      0.25      6.52      6.11      6.91   9269.95      1.00
     sigma      0.36      0.01      0.36      0.35      0.37   9550.55      1.00
 weight[0]      2.38      0.86      2.37      1.04      3.74   9323.47      1.00
 weight[1]     -1.61      0.37     -1.61     -2.19     -1.03   9496.14      1.00
 weight[2]      1.72      0.28      1.72      1.26      2.16   9584.92      1.00
 weight[3]     -0.86      0.28     -0.86     -1.32     -0.44   9534.39      1.00
 weight[4]      0.15      0.27      0.15     -0.28      0.58   9329.71      1.00
 weight[5]     -1.84      0.27     -1.84     -2.26     -1.40   9282.52      1.00
 weight[6]      0.95      0.27      0.95      0.53      1.37   9499.81      1.00
 weight[7]     -2.04      0.27     -2.04     -2.47     -1.61   9144.56      1.00
 weight[8]      1.86      0.27      1.86      1.42      2.26   9494.26      1.00
 weight[9]     -2.13      0.27     -2.12     -2.56     -1.70   9301.43      1.00
weight[10]      0.44      0.27      0.45      0.03      0.88   9275.83      1.00
weight[11]     -1.65      0.27     -1.65     -2.10     -1.24   9294.14      1.00
weight[12]     -0.10      0.27     -0.10     -0.52      0.33   9388.49      1.00
weight[13]     -1.48      0.27     -1.48     -1.93     -1.05   9098.48      1.00
weight[14]      0.39      0.28      0.39     -0.05      0.85   9669.49      1.00
weight[15]      1.91      0.34      1.91      1.39      2.47   9564.52      1.00
weight[16]      1.86      0.76      1.86      0.65      3.07   9867.93      1.00

Code 4.77

[97]:
# Code 4.77: each basis function scaled by its posterior-mean weight.
weight_mean = posterior_samples["weight"].mean(axis=0)
plt.subplot(
    xlabel="year",
    ylabel="basis * weight",
    xlim=(df2.year.min(), df2.year.max()),
    ylim=(-2, 2),
)
for basis_function, mean_weight in zip(B.T, weight_mean):
    plt.plot(df2["year"], mean_weight * basis_function, "k", alpha=0.2)
../_images/notebooks_04_geocentric_models_173_0.png

Code 4.78

[98]:
# Code 4.78: observed temperatures with the 89% HPDI band of the spline mu.
mu_hpdi = pd.DataFrame(numpyro.diagnostics.hpdi(posterior_samples["mu"], prob=0.89))
df2.plot(kind="scatter", x="year", y="temp", backend="matplotlib", figsize=(10, 6))
mu_band_lower, mu_band_upper = mu_hpdi.loc[0, :], mu_hpdi.loc[1, :]
plt.fill_between(df2["year"], mu_band_lower, mu_band_upper, color="k", alpha=0.5)
[98]:
<matplotlib.collections.PolyCollection at 0x74c881172fc0>
../_images/notebooks_04_geocentric_models_175_1.png

Code 4.79

[99]:
def m4_7_alt(
    B,
    *,
    alpha_prior={"loc": 6, "scale": 10},
    weight_prior={"loc": 0, "scale": 1},
    sigma_prior={"rate": 1},
    temp=None,
):
    """Spline model m4.7 with the linear predictor as an explicit weighted sum.

    mu = alpha + sum_j B[:, j] * weight[j], computed via broadcasting instead
    of a matrix product. Each ``*_prior`` dict is unpacked into its
    distribution; ``temp=None`` simulates instead of conditioning.
    """
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    weight = numpyro.sample(
        "weight", dist.Normal(**weight_prior), sample_shape=B.shape[1:]
    )
    mu = numpyro.deterministic("mu", alpha + jnp.sum(B * weight, axis=-1))
    sigma = numpyro.sample("sigma", dist.Exponential(**sigma_prior))
    temp = numpyro.sample("temp", dist.Normal(loc=mu, scale=sigma), obs=temp)
    return temp


# Start every spline weight at 0. The init key must be the sample-site name
# "weight"; init_to_value only uses entries that name a sample site.
guide = AutoLaplaceApproximation(
    m4_7_alt,
    init_loc_fn=numpyro.infer.init_to_value(
        values={"weight": jnp.zeros(B.shape[1])}
    ),
)
# Fit with the freshly split key rather than the notebook-level ``jrng``.
_, _jrng = jax.random.split(_jrng)
svi = SVI(
    model=m4_7_alt,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.5),
    loss=Trace_ELBO(),
    B=B,
    temp=df2["temp"].values,
).run(_jrng, 10_000)

_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(10_000,))
_posterior_samples = {k: v for k, v in posterior_samples.items() if k != "mu"}
numpyro.diagnostics.print_summary(_posterior_samples, prob=0.89, group_by_chain=False)
100%|█████████████| 10000/10000 [00:02<00:00, 3365.08it/s, init loss: 5767.3096, avg. loss [9501-10000]: 483.9176]

                mean       std    median      5.5%     94.5%     n_eff     r_hat
     alpha      6.52      0.25      6.52      6.13      6.92   9591.55      1.00
     sigma      0.36      0.01      0.36      0.35      0.37  10219.07      1.00
 weight[0]      2.39      0.86      2.39      1.12      3.85   9671.32      1.00
 weight[1]     -1.61      0.36     -1.61     -2.17     -1.05   9958.04      1.00
 weight[2]      1.72      0.28      1.72      1.26      2.17   9412.93      1.00
 weight[3]     -0.86      0.27     -0.86     -1.28     -0.42   9821.61      1.00
 weight[4]      0.15      0.27      0.15     -0.26      0.59   9517.43      1.00
 weight[5]     -1.84      0.27     -1.84     -2.28     -1.43   9679.75      1.00
 weight[6]      0.96      0.26      0.95      0.51      1.36   9832.37      1.00
 weight[7]     -2.04      0.27     -2.04     -2.45     -1.60   9217.47      1.00
 weight[8]      1.86      0.26      1.86      1.44      2.28   9860.78      1.00
 weight[9]     -2.13      0.27     -2.13     -2.54     -1.70   9499.83      1.00
weight[10]      0.45      0.26      0.45      0.03      0.86   9563.80      1.00
weight[11]     -1.65      0.27     -1.65     -2.06     -1.21   9835.99      1.00
weight[12]     -0.10      0.27     -0.10     -0.52      0.32   9681.86      1.00
weight[13]     -1.48      0.27     -1.48     -1.91     -1.04   9888.25      1.00
weight[14]      0.39      0.28      0.39     -0.06      0.84   9416.45      1.00
weight[15]      1.91      0.34      1.91      1.38      2.46   9785.94      1.00
weight[16]      1.86      0.76      1.86      0.64      3.05   9780.90      1.00

Easy

4E1

The first line is the likelihood.

4E2

Two parameters: \(\mu\) and \(\sigma\).

4E3

\[P(\mu, \sigma | y) = \frac{P(y | \mu, \sigma) \cdot P(\mu) \cdot P(\sigma)}{\int \int P(y | \mu, \sigma) \cdot P(\mu) \cdot P(\sigma) d \mu d \sigma}\]

4E4

The second line is the linear model.

4E5

Three parameters: \(\alpha\), \(\beta\), and \(\sigma\).

Medium

4M1

[100]:
def m1(y):
    """4M1 model: y ~ Normal(mu, sigma), mu ~ Normal(0, 10), sigma ~ Exponential(1)."""
    mu = numpyro.sample("mu", dist.Normal(loc=0, scale=10))
    sigma = numpyro.sample("sigma", dist.Exponential(rate=1))
    y_site = numpyro.sample("y", dist.Normal(loc=mu, scale=sigma), obs=y)
    return y_site


# Simulate 10,000 draws from the prior predictive (y=None) and plot their density.
_prior_predictive = numpyro.infer.Predictive(m1, num_samples=10_000)
prior_predictive_samples = pd.DataFrame(_prior_predictive(jrng, y=None))
prior_predictive_samples["y"].plot(kind="kde", backend="matplotlib")
[100]:
<Axes: ylabel='Density'>
../_images/notebooks_04_geocentric_models_186_1.png

4M2

[101]:
def m1(y):
    """4M2: the 4M1 model written as a NumPyro model."""
    location = numpyro.sample("mu", dist.Normal(loc=0, scale=10))
    spread = numpyro.sample("sigma", dist.Exponential(rate=1))
    return numpyro.sample("y", dist.Normal(loc=location, scale=spread), obs=y)

4M3

\[\begin{aligned} y_i & \sim \text{Normal}(\mu_i, \sigma) \\ \mu_i & = \alpha + \beta \cdot x_i \\ \alpha & \sim \text{Normal}(0, 10) \\ \beta & \sim \text{Uniform}(0, 1) \\ \sigma & \sim \text{Exponential}(1) \end{aligned}\]

4M4

\[\begin{aligned} \text{height}_i & \sim \text{Normal}(\mu_i, \sigma) \\ \mu_i & = \alpha + \beta \cdot (\text{year}_i - \text{start year}) \\ \alpha & \sim \text{Normal}(170, 29) \\ \beta & \sim \text{Normal}(0, 1) \\ \sigma & \sim \text{Exponential}(1) \end{aligned}\]

4M5

\[\begin{aligned} \text{height}_i & \sim \text{Normal}(\mu_i, \sigma) \\ \mu_i & = \alpha + \beta \cdot (\text{year}_i - \text{start year}) \\ \alpha & \sim \text{Normal}(170, 29) \\ \beta & \sim \text{LogNormal}(0, 1) \\ \sigma & \sim \text{Exponential}(1) \end{aligned}\]

4M6

We don’t want to change the prior after peeking at the data: a prior should encode only what we knew before seeing the sample.

Hard

4H1

[102]:
df = pd.read_csv("../data/Howell1.csv", sep=";")
df["weight_z_score"] = df["weight"].sub(df["weight"].mean()).div(df["weight"].std())


def h1(
    weight_z_score,
    *,
    alpha_prior={"loc": 170, "scale": 20},
    beta1_prior={"loc": 0, "scale": 1},
    beta2_prior={"loc": 0, "scale": 1},
    sigma_prior={"rate": 1 / 10},
    height=None,
):
    """Quadratic regression of height on standardized weight (4H1).

    mu = alpha + beta1 * z + beta2 * z**2. Each ``*_prior`` dict is unpacked
    into its distribution. Pass ``height=None`` to simulate heights instead
    of conditioning on them.
    """
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta1 = numpyro.sample("beta1", dist.Normal(**beta1_prior))
    beta2 = numpyro.sample("beta2", dist.Normal(**beta2_prior))
    linear_term = beta1 * weight_z_score
    quadratic_term = beta2 * weight_z_score**2
    mu = numpyro.deterministic("mu", alpha + linear_term + quadratic_term)
    sigma = numpyro.sample("sigma", dist.Exponential(**sigma_prior))
    return numpyro.sample("height", dist.Normal(mu, sigma), obs=height)
[103]:
# Draw 100 prior predictive height curves for h1 across the observed weight range.
weights = jnp.linspace(df["weight"].min(), df["weight"].max(), 50)
weight_z_score = (weights - df["weight"].mean()) / df["weight"].std()
h1_prior_predictive = numpyro.infer.Predictive(h1, num_samples=10_000)
prior_predictive_samples = h1_prior_predictive(
    jrng, weight_z_score=weight_z_score, height=None
)
height_prior_predictive_samples = pd.DataFrame(prior_predictive_samples["height"])
curves = height_prior_predictive_samples.sample(n=100, random_state=seed)
# Each sampled row is one curve; transpose so plt.plot draws one line per row.
plt.plot(weights, curves.to_numpy().T, "k", alpha=0.2)
# Reference lines at 0 and 272.
plt.axhline(y=0, c="k", ls="--")
plt.axhline(y=272, c="k", ls="-", lw=0.5)
plt.title("Prior Predictive Height Samples")
[103]:
Text(0.5, 1.0, 'Prior Predictive Height Samples')
../_images/notebooks_04_geocentric_models_196_1.png
[104]:
# Prior predictive summary for h1 on the weight grid:
# - the mean of mu (line),
# - the 89% HPDI of mu (dark band),
# - the 89% HPDI of simulated heights (light band).
mu_mean = pd.DataFrame(prior_predictive_samples["mu"].mean(axis=0), index=weights)
mu_hpdi = pd.DataFrame(
    numpyro.diagnostics.hpdi(prior_predictive_samples["mu"], prob=0.89),
    columns=weights,
    index=["low", "high"],
)
height_hpdi = pd.DataFrame(
    numpyro.diagnostics.hpdi(prior_predictive_samples["height"], prob=0.89),
    columns=weights,
    index=["low", "high"],
)

plt.plot(weights, mu_mean, "k")
plt.fill_between(
    weights, mu_hpdi.loc["low"], mu_hpdi.loc["high"], color="k", alpha=0.5
)
plt.fill_between(
    weights, height_hpdi.loc["low"], height_hpdi.loc["high"], color="k", alpha=0.2
)
plt.xlabel("weight")
plt.ylabel("height")
plt.title("Prior Predictive Mean and 89% HPDI")
plt.show()
[104]:
<matplotlib.collections.PolyCollection at 0x74c87d5797f0>
../_images/notebooks_04_geocentric_models_197_1.png
[105]:
# Fit h1 to the full data set with a Laplace approximation optimised by SVI.
guide = AutoLaplaceApproximation(h1)
_, _jrng = jax.random.split(_jrng)
# `svi` holds the result of `.run(...)`; its `.params` are used to sample the posterior.
svi = SVI(
    model=h1,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.5),
    loss=Trace_ELBO(),
    weight_z_score=df["weight_z_score"].values,
    height=df["height"].values,
).run(_jrng, 10_000)  # use the freshly split key instead of reusing `jrng`
100%|█████████| 10000/10000 [00:02<00:00, 4174.72it/s, init loss: 2903255.0000, avg. loss [9501-10000]: 1981.9767]
[106]:
# Posterior predictive check for h1, overlaid on the data:
# - the posterior mean of mu (line),
# - its 89% HPDI (dark band),
# - the 89% HPDI of simulated heights (light band).
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(10_000,))
weights = jnp.linspace(df["weight"].min(), df["weight"].max(), 50)
weight_z_score = (weights - df["weight"].mean()) / df["weight"].std()
posterior_predictive = numpyro.infer.Predictive(
    guide.model, posterior_samples, return_sites=["mu", "height"]
)
posterior_predictive_samples = posterior_predictive(
    _jrng,
    weight_z_score=weight_z_score,
    height=None,
)

# Work on the raw sample arrays (rows = draws, columns = grid points).
# Passing DataFrames to hpdi triggered pandas' `DataFrame.swapaxes`
# FutureWarning. Everything here is indexed by this cell's own `weights` grid.
mu_samples = posterior_predictive_samples["mu"]
height_samples = posterior_predictive_samples["height"]
mu_mean = mu_samples.mean(axis=0)
mu_hpdi = numpyro.diagnostics.hpdi(mu_samples, prob=0.89)  # row 0 = low, row 1 = high
height_hpdi = numpyro.diagnostics.hpdi(height_samples, prob=0.89)

df[["weight", "height"]].plot(
    kind="scatter", x="weight", y="height", backend="matplotlib"
)
plt.plot(weights, mu_mean, "k")
plt.fill_between(weights, mu_hpdi[0], mu_hpdi[1], color="k", alpha=0.5)
plt.fill_between(weights, height_hpdi[0], height_hpdi[1], color="k", alpha=0.2)
plt.show()
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:

'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.

/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:

'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.

../_images/notebooks_04_geocentric_models_199_1.png
[107]:
# 4H1: expected height and 89% HPDI of simulated height for five new weights.
# Weights are standardized with the mean/std of the full data set in `df`.
weights = jnp.array([46.95, 43.72, 64.68, 32.59, 54.63])
weight_z_score = (weights - df["weight"].mean()) / df["weight"].std()
posterior_predictive_samples = posterior_predictive(
    _jrng,
    weight_z_score=weight_z_score,
    height=None,
)

# Summarise the raw array directly. Passing a DataFrame to hpdi triggered the
# `DataFrame.swapaxes` FutureWarning.
height_samples = posterior_predictive_samples["height"]
height_hpdi = numpyro.diagnostics.hpdi(height_samples, prob=0.89)  # (2, n_weights)

pd.DataFrame(
    {
        "weight": np.asarray(weights),
        "expected height": np.asarray(height_samples.mean(axis=0)),
        "89% hpdi low": np.asarray(height_hpdi[0]),
        "89% hpdi high": np.asarray(height_hpdi[1]),
    }
)
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:

'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.

/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:

'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.

[107]:
weight expected height 89% hpdi low 89% hpdi high
0 46.950001 157.183014 147.768356 166.736542
1 43.720001 155.287338 145.906158 164.688492
2 64.680000 152.824203 142.585907 162.118515
3 32.590000 142.363800 132.973160 151.893936
4 54.630001 158.352539 148.645493 167.660065

4H2

[108]:
# 4H2 uses only the rows with age below 18.
df = pd.read_csv("../data/Howell1.csv", sep=";").query("age < 18")
df
[108]:
height weight age male
18 121.920 19.617854 12.0 1
19 105.410 13.947954 8.0 0
20 86.360 10.489315 6.5 0
23 129.540 23.586784 13.0 1
24 109.220 15.989118 7.0 0
... ... ... ... ...
535 114.935 17.519991 7.0 1
536 67.945 7.229122 1.0 0
538 76.835 8.022908 1.0 1
539 145.415 31.127751 17.0 1
542 71.120 8.051258 0.0 1

192 rows × 4 columns

[109]:
def h2(
    weight,
    *,
    alpha_prior={"loc": 100, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"rate": 1 / 10},
    height=None,
):
    """Linear regression of height on (unstandardized) weight for 4H2.

    height ~ Normal(mu, sigma)
    mu     = alpha + beta * weight
    alpha  ~ Normal(**alpha_prior)
    beta   ~ LogNormal(**beta_prior)   # LogNormal keeps the slope positive
    sigma  ~ Exponential(**sigma_prior)

    Pass ``height=None`` to simulate heights, or observed heights to condition
    on them. Returns the "height" site value (previously returned None).
    """
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta = numpyro.sample("beta", dist.LogNormal(**beta_prior))
    mu = numpyro.deterministic("mu", alpha + beta * weight)
    sigma = numpyro.sample("sigma", dist.Exponential(**sigma_prior))
    height_forecast = numpyro.sample(
        "height", dist.Normal(loc=mu, scale=sigma), obs=height
    )
    return height_forecast
[110]:
# Draw 100 prior predictive height curves for h2 over the under-18 weight range.
weight = jnp.linspace(df["weight"].min(), df["weight"].max(), 50)
h2_prior_predictive = numpyro.infer.Predictive(h2, num_samples=10_000)
prior_predictive_samples = h2_prior_predictive(jrng, weight=weight, height=None)
height_prior_predictive_samples = pd.DataFrame(prior_predictive_samples["height"])
drawn_curves = height_prior_predictive_samples.sample(n=100, random_state=seed)
# One row per curve; transpose so plt.plot draws each row as its own line.
plt.plot(weight, drawn_curves.to_numpy().T, "k", alpha=0.2)
# Reference lines at 0 and 272.
plt.axhline(y=0, c="k", ls="--")
plt.axhline(y=272, c="k", ls="-", lw=0.5)
plt.title("Prior Predictive Height Samples")
[110]:
Text(0.5, 1.0, 'Prior Predictive Height Samples')
../_images/notebooks_04_geocentric_models_204_1.png
[111]:
# Fit h2 to the under-18 subset with a Laplace approximation optimised by SVI.
guide = AutoLaplaceApproximation(h2)
_, _jrng = jax.random.split(_jrng)
# `svi` holds the result of `.run(...)`; its `.params` are used to sample the posterior.
svi = SVI(
    model=h2,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.5),
    loss=Trace_ELBO(),
    weight=df["weight"].values,
    height=df["height"].values,
).run(_jrng, 10_000)  # use the freshly split key instead of reusing `jrng`
100%|███████████| 10000/10000 [00:02<00:00, 3916.53it/s, init loss: 232238.2500, avg. loss [9501-10000]: 691.3686]
[112]:
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(10_000,))
# Summarise every sampled site except "mu" (excluded if present).
_posterior_samples = dict(
    (site, draws) for site, draws in posterior_samples.items() if site != "mu"
)
numpyro.diagnostics.print_summary(
    _posterior_samples, prob=0.89, group_by_chain=False
)

                mean       std    median      5.5%     94.5%     n_eff     r_hat
     alpha     58.81      1.26     58.80     56.72     60.74   9597.75      1.00
      beta      2.83      0.06      2.83      2.73      2.93   9858.46      1.00
     sigma      8.29      0.44      8.28      7.59      8.98   9915.53      1.00

For every 10 kg increase in weight, the model expects about a 28 cm increase in height (10 \(\times\) 2.83, the posterior mean of \(\beta\)).

[113]:
# Posterior predictive check for h2 on the under-18 subset, overlaid on the data:
# - the posterior mean of mu (line),
# - its 89% HPDI (dark band),
# - the 89% HPDI of simulated heights (light band).
weight = jnp.linspace(df["weight"].min(), df["weight"].max(), 50)
posterior_predictive = numpyro.infer.Predictive(
    guide.model, posterior_samples, return_sites=["mu", "height"]
)
posterior_predictive_samples = posterior_predictive(
    _jrng,
    weight=weight,
    height=None,
)

# Work on the raw sample arrays (rows = draws, columns = grid points).
# Passing DataFrames to hpdi triggered pandas' `DataFrame.swapaxes` FutureWarning.
mu_samples = posterior_predictive_samples["mu"]
height_samples = posterior_predictive_samples["height"]
mu_mean = mu_samples.mean(axis=0)
mu_hpdi = numpyro.diagnostics.hpdi(mu_samples, prob=0.89)  # row 0 = low, row 1 = high
height_hpdi = numpyro.diagnostics.hpdi(height_samples, prob=0.89)

df[["weight", "height"]].plot(
    kind="scatter", x="weight", y="height", backend="matplotlib"
)
plt.plot(weight, mu_mean, "k")
plt.fill_between(weight, mu_hpdi[0], mu_hpdi[1], color="k", alpha=0.5)
plt.fill_between(weight, height_hpdi[0], height_hpdi[1], color="k", alpha=0.2)
plt.show()
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:

'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.

/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:

'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.

../_images/notebooks_04_geocentric_models_208_1.png

The data appear to have curvature that the linear model does not capture.

4H3

[114]:
# 4H3 uses the full Howell1 data (all ages).
df = pd.read_csv("../data/Howell1.csv", sep=";")


def h3(
    weight,
    *,
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"rate": 1 / 10},
    height=None,
):
    """Regression of height on log(weight).

    height ~ Normal(mu, sigma)
    mu     = alpha + beta * log(weight)
    alpha  ~ Normal(**alpha_prior)
    beta   ~ LogNormal(**beta_prior)
    sigma  ~ Exponential(**sigma_prior)

    Pass ``height=None`` to simulate heights, or observed heights to condition
    on them. Returns the "height" site value.
    """
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta = numpyro.sample("beta", dist.LogNormal(**beta_prior))
    log_weight = jnp.log(weight)
    mu = numpyro.deterministic("mu", alpha + beta * log_weight)
    sigma = numpyro.sample("sigma", dist.Exponential(**sigma_prior))
    observed = numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=height)
    return observed
[115]:
# Draw 100 prior predictive curves of the expected height mu for h3 over the
# full weight range (mu is plotted here, not simulated heights).
weight = jnp.linspace(df["weight"].min(), df["weight"].max(), 50)
h3_prior_predictive = numpyro.infer.Predictive(h3, num_samples=10_000)
prior_predictive_samples = h3_prior_predictive(jrng, weight=weight, height=None)
mu_prior_predictive_samples = pd.DataFrame(prior_predictive_samples["mu"])
height_prior_predictive_samples = pd.DataFrame(prior_predictive_samples["height"])
mu_curves = mu_prior_predictive_samples.sample(n=100, random_state=seed)
# One row per curve; transpose so plt.plot draws each row as its own line.
plt.plot(weight, mu_curves.to_numpy().T, "k", alpha=0.2)
# Reference lines at 0 and 272.
plt.axhline(y=0, c="k", ls="--")
plt.axhline(y=272, c="k", ls="-", lw=0.5)
plt.title("Prior Predictive Expected Height Samples")
[115]:
Text(0.5, 1.0, 'Prior Predictive Expected Height Samples')
../_images/notebooks_04_geocentric_models_212_1.png
[116]:
# Fit h3 to the full data set with a Laplace approximation optimised by SVI.
guide = AutoLaplaceApproximation(h3)

_, _jrng = jax.random.split(_jrng)
# `svi` holds the result of `.run(...)`; its `.params` are used to sample the posterior.
svi = SVI(
    model=h3,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.5),
    loss=Trace_ELBO(),
    weight=df["weight"].values,
    height=df["height"].values,
).run(_jrng, 10_000)  # use the freshly split key instead of reusing `jrng`
100%|█████████| 10000/10000 [00:02<00:00, 3985.92it/s, init loss: 1202857.2500, avg. loss [9501-10000]: 2229.7351]
[117]:
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(10_000,))
# Summarise every sampled site except "mu" (excluded if present).
_posterior_samples = {
    site: posterior_samples[site] for site in posterior_samples if site != "mu"
}
numpyro.diagnostics.print_summary(
    _posterior_samples, prob=0.89, group_by_chain=False
)

                mean       std    median      5.5%     94.5%     n_eff     r_hat
     alpha    -18.85      1.38    -18.84    -21.15    -16.72   9436.35      1.00
      beta     45.72      0.39     45.72     45.11     46.37   9430.39      1.00
     sigma      5.24      0.16      5.24      4.98      5.50   9364.90      1.00

[118]:
# Posterior predictive check for h3 (height on log weight, full data), overlaid on the data:
# - the posterior mean of mu (line),
# - its 89% HPDI (dark band),
# - the 89% HPDI of simulated heights (light band).
weight = jnp.linspace(df["weight"].min(), df["weight"].max(), 50)
posterior_predictive = numpyro.infer.Predictive(
    guide.model, posterior_samples, return_sites=["mu", "height"]
)
posterior_predictive_samples = posterior_predictive(
    _jrng,
    weight=weight,
    height=None,
)

# Work on the raw sample arrays (rows = draws, columns = grid points).
# Passing DataFrames to hpdi triggered pandas' `DataFrame.swapaxes` FutureWarning.
mu_samples = posterior_predictive_samples["mu"]
height_samples = posterior_predictive_samples["height"]
mu_mean = mu_samples.mean(axis=0)
mu_hpdi = numpyro.diagnostics.hpdi(mu_samples, prob=0.89)  # row 0 = low, row 1 = high
height_hpdi = numpyro.diagnostics.hpdi(height_samples, prob=0.89)

df[["weight", "height"]].plot(
    kind="scatter", x="weight", y="height", backend="matplotlib"
)
plt.plot(weight, mu_mean, "k")
plt.fill_between(weight, mu_hpdi[0], mu_hpdi[1], color="k", alpha=0.5)
plt.fill_between(weight, height_hpdi[0], height_hpdi[1], color="k", alpha=0.2)
plt.show()
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:

'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.

/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:

'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.

../_images/notebooks_04_geocentric_models_215_1.png

4H4

[119]:
# Full Howell1 data plus a standardized weight column (mean 0, sample std 1).
df = pd.read_csv("../data/Howell1.csv", sep=";").assign(
    weight_z_score=lambda data: (data["weight"] - data["weight"].mean())
    / data["weight"].std()
)


def h4(
    weight_z_score,
    *,
    alpha_prior={"loc": 170, "scale": 20},
    beta1_prior={"loc": 0, "scale": 1},
    beta2_prior={"loc": 0, "scale": 1},
    sigma_prior={"rate": 1 / 10},
    height=None,
):
    """4H4: quadratic regression of height on standardized weight.

    height ~ Normal(mu, sigma)
    mu     = alpha + beta1 * z + beta2 * z**2
    alpha  ~ Normal(**alpha_prior)
    beta1  ~ Normal(**beta1_prior), beta2 ~ Normal(**beta2_prior)
    sigma  ~ Exponential(**sigma_prior)

    The prior hyperparameters are keyword arguments, so alternative priors can
    be passed in without redefining the model. Returns the "height" site value.
    """
    intercept = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    slope = numpyro.sample("beta1", dist.Normal(**beta1_prior))
    curvature = numpyro.sample("beta2", dist.Normal(**beta2_prior))
    expected = intercept + slope * weight_z_score + curvature * weight_z_score**2
    mu = numpyro.deterministic("mu", expected)
    sigma = numpyro.sample("sigma", dist.Exponential(**sigma_prior))
    return numpyro.sample("height", dist.Normal(mu, sigma), obs=height)
[120]:
# Draw 100 prior predictive height curves for h4 across the observed weight range.
weights = jnp.linspace(df["weight"].min(), df["weight"].max(), 50)
weight_z_score = (weights - df["weight"].mean()) / df["weight"].std()
h4_prior_predictive = numpyro.infer.Predictive(h4, num_samples=10_000)
prior_predictive_samples = h4_prior_predictive(
    jrng, weight_z_score=weight_z_score, height=None
)
height_prior_predictive_samples = pd.DataFrame(prior_predictive_samples["height"])
sampled_rows = height_prior_predictive_samples.sample(n=100, random_state=seed)
# One row per curve; transpose so plt.plot draws each row as its own line.
plt.plot(weights, sampled_rows.to_numpy().T, "k", alpha=0.2)
# Reference lines at 0 and 272.
plt.axhline(y=0, c="k", ls="--")
plt.axhline(y=272, c="k", ls="-", lw=0.5)
plt.title("Prior Predictive Height Samples")
[120]:
Text(0.5, 1.0, 'Prior Predictive Height Samples')
../_images/notebooks_04_geocentric_models_218_1.png
[ ]: